wf_sweep.f90 Source File


Source Code

module wf_sweep
  !! Sweep management: configuration builder, sweep registration, and agents.
  use iso_c_binding, only: c_int, c_double, c_char, c_null_char, c_float
  use wf_c_bindings
  implicit none
  private

  public :: wandb_sweep_config_type
  public :: wandb_sweep
  public :: wandb_agent
  public :: wandb_sweep_start_agent
  public :: wandb_sweep_next_params
  public :: wandb_sweep_run_done


  !-----------------------------------------------------------------------------
  ! wandb_sweep_config_type — programmatic sweep configuration builder
  !-----------------------------------------------------------------------------
  !! Builder for a wandb sweep configuration.
  !!
  !! Build the configuration with the type-bound procedures, then pass
  !! it directly to wandb_sweep or call to_json() for the raw string.
  !!
  !! ## Example
  !!
  !! ```fortran
  !! type(wandb_sweep_config_type) :: cfg
  !!
  !! call cfg%set_method("bayes")
  !! call cfg%set_metric("val_loss", "minimize")
  !!
  !! call cfg%add_param_range("learning_rate", 1e-5_real64, 1e-2_real64, &
  !!                           distribution="log_uniform_values")
  !! call cfg%add_param_values("hidden_size", [32, 64, 128, 256])
  !! call cfg%add_param_values("activation", ["relu   ", "tanh   ", "sigmoid"])
  !!
  !! call wandb_sweep(cfg, project="my-project", sweep_id=id)
  !! ```
  type :: wandb_sweep_config_type
    character(len=:), allocatable :: method_str
    character(len=:), allocatable :: metric_name_str
    character(len=:), allocatable :: metric_goal_str
    character(len=:), allocatable :: params_buf  ! accumulated parameter JSON
    integer :: num_params = 0
  contains
    procedure :: set_method        => swcfg_set_method
    procedure :: set_metric        => swcfg_set_metric
    procedure, private :: add_range_r32   => swcfg_add_range_r32
    procedure, private :: add_range_r64   => swcfg_add_range_r64
    procedure, private :: add_vals_int    => swcfg_add_values_int
    procedure, private :: add_vals_r32    => swcfg_add_values_r32
    procedure, private :: add_vals_r64    => swcfg_add_values_r64
    procedure, private :: add_vals_str    => swcfg_add_values_str
    generic :: add_param_range  => add_range_r32, add_range_r64
    generic :: add_param_values => add_vals_int, add_vals_r32, &
                                    add_vals_r64, add_vals_str
    procedure :: to_json => swcfg_to_json
  end type wandb_sweep_config_type


  !! Register a sweep — accepts either a raw JSON string or a
  !! wandb_sweep_config_type built with the helper API.
  interface wandb_sweep
    module procedure wandb_sweep_json_str
    module procedure wandb_sweep_cfg_obj
  end interface wandb_sweep

contains

  !-----------------------------------------------------------------------------
  ! wandb_sweep  (raw JSON string variant)
  !-----------------------------------------------------------------------------
  subroutine wandb_sweep_json_str(config_json, project, sweep_id, entity)
    !! Register a hyperparameter sweep with wandb.
    !!
    !! @param config_json  JSON string describing the sweep (method, metric,
    !!                     parameters).
    !! @param project      wandb project name.
    !! @param sweep_id     Output: the sweep ID string assigned by wandb.
    !! @param entity       Optional wandb entity/team.
    character(len=*),           intent(in)  :: config_json
    character(len=*),           intent(in)  :: project
    character(len=*),           intent(out) :: sweep_id
    character(len=*), optional, intent(in)  :: entity

    integer, parameter :: BUF_LEN = 256
    character(kind=c_char) :: id_buf(BUF_LEN)
    character(len=:), allocatable :: c_entity
    integer(c_int) :: rc
    integer :: i

    if(present(entity))then
       c_entity = entity // c_null_char
    else
       c_entity = c_null_char
    end if

    id_buf = c_null_char
    rc = wandb_sweep_c( &
         config_json // c_null_char, &
         project     // c_null_char, &
         c_entity,                   &
         id_buf,                     &
         int(BUF_LEN, c_int)         &
    )

    if(rc /= 0)then
       write(0,*) "[wf] WARNING: wandb_sweep failed (rc=", rc, ")"
       sweep_id = ' '
       return
    end if

    sweep_id = ' '
    do i = 1, len(sweep_id)
       if(id_buf(i) == c_null_char) exit
       sweep_id(i:i) = id_buf(i)
    end do

  end subroutine wandb_sweep_json_str


  !-----------------------------------------------------------------------------
  ! wandb_sweep  (config-type variant)
  !-----------------------------------------------------------------------------
  subroutine wandb_sweep_cfg_obj(config, project, sweep_id, entity)
    !! Register a sweep from a wandb_sweep_config_type builder object.
    type(wandb_sweep_config_type), intent(in)  :: config
    character(len=*),              intent(in)  :: project
    character(len=*),              intent(out) :: sweep_id
    character(len=*), optional,    intent(in)  :: entity

    character(len=:), allocatable :: json

    json = config%to_json()
    if(present(entity))then
       call wandb_sweep_json_str(json, project, sweep_id, entity)
    else
       call wandb_sweep_json_str(json, project, sweep_id)
    end if

  end subroutine wandb_sweep_cfg_obj


  !-----------------------------------------------------------------------------
  ! wandb_agent
  !-----------------------------------------------------------------------------
  subroutine wandb_agent(sweep_id, project, count, entity)
    !! Run a wandb sweep agent.
    !!
    !! Calls wandb.agent(sweep_id, count=count, project=project).  After this
    !! returns, call wandb_config_get to read the sampled hyperparameters,
    !! run your training loop, then call wandb_finish.
    !!
    !! @param sweep_id  Sweep ID returned by wandb_sweep.
    !! @param project   wandb project name.
    !! @param count     Number of runs to execute (0 = until sweep is done).
    !! @param entity    Optional wandb entity/team.
    character(len=*),           intent(in) :: sweep_id
    character(len=*),           intent(in) :: project
    integer,          optional, intent(in) :: count
    character(len=*), optional, intent(in) :: entity

    character(len=:), allocatable :: c_entity
    integer(c_int) :: rc, c_count

    c_count = 0_c_int
    if(present(count)) c_count = int(count, c_int)

    if(present(entity))then
       c_entity = entity // c_null_char
    else
       c_entity = c_null_char
    end if

    rc = wandb_agent_c( &
         sweep_id // c_null_char, &
         project  // c_null_char, &
         c_entity,                &
         c_count                  &
    )

    if(rc /= 0)then
       write(0,*) "[wf] WARNING: wandb_agent failed (rc=", rc, ")"
    end if

  end subroutine wandb_agent


  !-----------------------------------------------------------------------------
  ! wandb_sweep_start_agent
  !-----------------------------------------------------------------------------
  subroutine wandb_sweep_start_agent(sweep_id, project, count, entity)
    !! Start a wandb sweep agent in a background Python thread.
    !!
    !! ## Usage pattern
    !!
    !! ```fortran
    !! call wandb_sweep(config=cfg, project="my-proj", sweep_id=sid)
    !! call wandb_sweep_start_agent(sid, "my-proj", count=5)
    !! do i = 1, 5
    !!    call wandb_sweep_next_params(params_json)
    !!    ! parse params_json, train, log ...
    !!    call wandb_sweep_run_done()
    !! end do
    !! call wandb_shutdown()
    !! ```
    character(len=*),           intent(in) :: sweep_id
    character(len=*),           intent(in) :: project
    integer,                    intent(in) :: count
    character(len=*), optional, intent(in) :: entity

    character(len=:), allocatable :: c_entity
    integer(c_int) :: rc

    if(present(entity))then
       c_entity = trim(entity) // c_null_char
    else
       c_entity = c_null_char
    end if

    rc = wandb_sweep_start_agent_c( &
         trim(sweep_id) // c_null_char, &
         trim(project)  // c_null_char, &
         c_entity,                       &
         int(count, c_int)               &
    )

    if(rc /= 0)then
       write(0,*) "[wf] WARNING: wandb_sweep_start_agent failed (rc=", rc, ")"
    end if

  end subroutine wandb_sweep_start_agent


  !-----------------------------------------------------------------------------
  ! wandb_sweep_next_params
  !-----------------------------------------------------------------------------
  subroutine wandb_sweep_next_params(params_json, timeout_s)
    !! Block until the sweep agent has started the next run and populated
    !! wandb.config with the sweep-sampled hyperparameters.
    !!
    !! @param params_json  Receives the sampled hyperparameters as a JSON
    !!                     object string, e.g. {"lr":0.01,"hidden":32}.
    !! @param timeout_s    Seconds to wait before giving up (default: 120.0).
    character(len=*),          intent(out) :: params_json
    real(c_double), optional,  intent(in)  :: timeout_s

    character(kind=c_char), dimension(4096) :: c_buf
    real(c_double) :: tval
    integer(c_int) :: ok
    integer :: i

    tval = 120.0_c_double
    if(present(timeout_s)) tval = timeout_s

    ok = wandb_sweep_params_c(c_buf, int(4096, c_int), tval)

    if(ok == 0_c_int)then
       write(0,*) "[wf] WARNING: wandb_sweep_next_params timed out."
       params_json = '{}'
       return
    end if

    params_json = ' '
    do i = 1, len(params_json)
       if(c_buf(i) == c_null_char) exit
       params_json(i:i) = c_buf(i)
    end do

  end subroutine wandb_sweep_next_params


  !-----------------------------------------------------------------------------
  ! wandb_sweep_run_done
  !-----------------------------------------------------------------------------
  subroutine wandb_sweep_run_done()
    !! Signal that the current sweep run's training is finished.
    !!
    !! The sweep agent callback will call wandb.finish() and request the next
    !! set of hyperparameters from the sweep controller.
    !! Call this once per sweep run, after all logging is done.
    call wandb_sweep_run_done_c()
  end subroutine wandb_sweep_run_done


  !=============================================================================
  ! wandb_sweep_config_type — type-bound procedure implementations
  !=============================================================================

  subroutine swcfg_set_method(self, method)
    !! Set the sweep search method: "bayes" | "grid" | "random".
    class(wandb_sweep_config_type), intent(inout) :: self
    character(len=*),               intent(in)    :: method

    self%method_str = trim(method)

  end subroutine swcfg_set_method


  subroutine swcfg_set_metric(self, name, goal)
    !! Set the optimisation metric.
    !! @param name  Metric key logged via wandb_log (e.g. "val_loss").
    !! @param goal  "minimize" or "maximize".
    class(wandb_sweep_config_type), intent(inout) :: self
    character(len=*),               intent(in)    :: name
    character(len=*),               intent(in)    :: goal

    self%metric_name_str = trim(name)
    self%metric_goal_str = trim(goal)

  end subroutine swcfg_set_metric


  !---------------------------------------------------------------------------
  ! add_param_range  (real32 / real64)
  !---------------------------------------------------------------------------
  subroutine swcfg_add_range_r32(self, name, min_val, max_val, distribution)
    !! Add a continuous hyperparameter with min/max bounds (single precision).
    !! @param distribution  wandb distribution string (default "uniform").
    class(wandb_sweep_config_type), intent(inout) :: self
    character(len=*),               intent(in)    :: name
    real(c_float),                  intent(in)    :: min_val, max_val
    character(len=*), optional,     intent(in)    :: distribution

    call swcfg_add_range_r64(self, name, &
         real(min_val, c_double), real(max_val, c_double), distribution)

  end subroutine swcfg_add_range_r32


  subroutine swcfg_add_range_r64(self, name, min_val, max_val, distribution)
    !! Add a continuous hyperparameter with min/max bounds (double precision).
    !! @param distribution  wandb distribution string (default "uniform").
    class(wandb_sweep_config_type), intent(inout) :: self
    character(len=*),               intent(in)    :: name
    real(c_double),                 intent(in)    :: min_val, max_val
    character(len=*), optional,     intent(in)    :: distribution

    character(len=:), allocatable :: distrib, fragment

    if(present(distribution))then
       distrib = trim(distribution)
    else
       distrib = "uniform"
    end if

    fragment = '"' // trim(name) // '":{' // &
               '"distribution":"' // distrib // '",' // &
               '"min":' // r64_to_json(min_val) // ',' // &
               '"max":' // r64_to_json(max_val) // '}'

    if(.not. allocated(self%params_buf)) self%params_buf = ''
    if(self%num_params > 0)then
       self%params_buf = self%params_buf // ',' // fragment
    else
       self%params_buf = self%params_buf // fragment
    end if
    self%num_params = self%num_params + 1

  end subroutine swcfg_add_range_r64


  !---------------------------------------------------------------------------
  ! add_param_values  (integer array)
  !---------------------------------------------------------------------------
  subroutine swcfg_add_values_int(self, name, values)
    !! Add a discrete integer hyperparameter.
    class(wandb_sweep_config_type), intent(inout) :: self
    character(len=*),               intent(in)    :: name
    integer,                        intent(in)    :: values(:)

    character(len=:), allocatable :: arr, fragment
    character(len=32) :: buf
    integer :: i

    arr = '['
    do i = 1, size(values)
       write(buf, '(i0)') values(i)
       if(i > 1) arr = arr // ','
       arr = arr // trim(buf)
    end do
    arr = arr // ']'

    fragment = '"' // trim(name) // '":{"values":' // arr // '}'
    if(.not. allocated(self%params_buf)) self%params_buf = ''
    if(self%num_params > 0)then
       self%params_buf = self%params_buf // ',' // fragment
    else
       self%params_buf = self%params_buf // fragment
    end if
    self%num_params = self%num_params + 1

  end subroutine swcfg_add_values_int


  !---------------------------------------------------------------------------
  ! add_param_values  (real32 array)
  !---------------------------------------------------------------------------
  subroutine swcfg_add_values_r32(self, name, values)
    !! Add a discrete real-valued hyperparameter (single precision).
    class(wandb_sweep_config_type), intent(inout) :: self
    character(len=*),               intent(in)    :: name
    real(c_float),                  intent(in)    :: values(:)

    real(c_double), allocatable :: tmp(:)
    integer :: i

    allocate(tmp(size(values)))
    do i = 1, size(values)
       tmp(i) = real(values(i), c_double)
    end do
    call swcfg_add_values_r64(self, name, tmp)

  end subroutine swcfg_add_values_r32


  !---------------------------------------------------------------------------
  ! add_param_values  (real64 array)
  !---------------------------------------------------------------------------
  subroutine swcfg_add_values_r64(self, name, values)
    !! Add a discrete real-valued hyperparameter (double precision).
    class(wandb_sweep_config_type), intent(inout) :: self
    character(len=*),               intent(in)    :: name
    real(c_double),                 intent(in)    :: values(:)

    character(len=:), allocatable :: arr, fragment
    integer :: i

    arr = '['
    do i = 1, size(values)
       if(i > 1) arr = arr // ','
       arr = arr // r64_to_json(values(i))
    end do
    arr = arr // ']'

    fragment = '"' // trim(name) // '":{"values":' // arr // '}'
    if(.not. allocated(self%params_buf)) self%params_buf = ''
    if(self%num_params > 0)then
       self%params_buf = self%params_buf // ',' // fragment
    else
       self%params_buf = self%params_buf // fragment
    end if
    self%num_params = self%num_params + 1

  end subroutine swcfg_add_values_r64


  !---------------------------------------------------------------------------
  ! add_param_values  (string array)
  !---------------------------------------------------------------------------
  subroutine swcfg_add_values_str(self, name, values)
    !! Add a discrete string hyperparameter.
    !! All elements of values must have the same declared length; trailing
    !! spaces are trimmed automatically.
    class(wandb_sweep_config_type), intent(inout) :: self
    character(len=*),               intent(in)    :: name
    character(len=*),               intent(in)    :: values(:)

    character(len=:), allocatable :: arr, fragment
    integer :: i

    arr = '['
    do i = 1, size(values)
       if(i > 1) arr = arr // ','
       arr = arr // '"' // trim(values(i)) // '"'
    end do
    arr = arr // ']'

    fragment = '"' // trim(name) // '":{"values":' // arr // '}'
    if(.not. allocated(self%params_buf)) self%params_buf = ''
    if(self%num_params > 0)then
       self%params_buf = self%params_buf // ',' // fragment
    else
       self%params_buf = self%params_buf // fragment
    end if
    self%num_params = self%num_params + 1

  end subroutine swcfg_add_values_str


  !---------------------------------------------------------------------------
  ! to_json
  !---------------------------------------------------------------------------
  function swcfg_to_json(self) result(json)
    !! Serialise the configuration to a JSON string suitable for wandb_sweep.
    class(wandb_sweep_config_type), intent(in) :: self
    character(len=:), allocatable :: json

    character(len=:), allocatable :: method, metric_part, params_part

    if(allocated(self%method_str))then
       method = self%method_str
    else
       method = 'bayes'
    end if

    if(allocated(self%metric_name_str))then
       metric_part = ',"metric":{"name":"' // self%metric_name_str // &
                     '","goal":"'           // self%metric_goal_str // '"}'
    else
       metric_part = ''
    end if

    if(allocated(self%params_buf) .and. len(self%params_buf) > 0)then
       params_part = ',"parameters":{' // self%params_buf // '}'
    else
       params_part = ',"parameters":{}'
    end if

    json = '{"method":"' // trim(method) // '"' // &
            metric_part // params_part // '}'

  end function swcfg_to_json


  !---------------------------------------------------------------------------
  ! Private helper: format a real64 as a compact JSON number string.
  !---------------------------------------------------------------------------
  function r64_to_json(x) result(s)
    real(c_double), intent(in) :: x
    character(len=32) :: s
    integer :: i

    write(s, '(ES23.8E3)') x
    s = adjustl(s)
    do i = 1, len_trim(s)
       if(s(i:i) == 'E') s(i:i) = 'e'
    end do
    s = trim(s)

  end function r64_to_json

end module wf_sweep