Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Feb 6, 2025
1 parent 0c8c064 commit 86f5085
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 291 deletions.
2 changes: 0 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,10 @@ Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Collate:
'OptimizerAsyncSuccessiveHalving.R'
'OptimizerAsyncSuccessiveHalving2.R'
'aaa.R'
'OptimizerBatchSuccessiveHalving.R'
'OptimizerBatchHyperband.R'
'TunerAsyncSuccessiveHalving.R'
'TunerAsyncSuccessiveHalving2.R'
'TunerBatchHyperband.R'
'TunerBatchSuccessiveHalving.R'
'bibentries.R'
Expand Down
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Generated by roxygen2: do not edit by hand

export(OptimizerAsyncSuccessiveHalving)
export(OptimizerAsyncSuccessiveHalving2)
export(OptimizerBatchHyperband)
export(OptimizerBatchSuccessiveHalving)
export(TunerAsyncSuccessiveHalving)
export(TunerAsyncSuccessiveHalving2)
export(TunerBatchHyperband)
export(TunerBatchSuccessiveHalving)
export(hyperband_budget)
Expand Down
179 changes: 133 additions & 46 deletions R/OptimizerAsyncSuccessiveHalving.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ OptimizerAsyncSuccessiveHalving = R6Class("OptimizerAsyncSuccessiveHalving",

# number of stages if each configuration in the first stage uses the minimum budget
# and each configuration in the last stage uses no more than maximum budget
private$.s_max = floor(log(r, eta))
s_max = ceiling(log(r, eta))
# using ceiling can produce one stage too much but floor can produce one stage too few due to floating point errors
# thus we need to check that the last stage is not over the maximum budget
if (r_min * eta^s_max > r_max) s_max = s_max - 1
private$.s_max = s_max

lg$info("Starting successive halving with eta = %g and %i stages from %g to %g %s", eta, private$.s_max + 1, r_min, r_min * eta^s_max, budget_id)

optimize_async_default(inst, self)
}
Expand All @@ -128,67 +134,148 @@ OptimizerAsyncSuccessiveHalving = R6Class("OptimizerAsyncSuccessiveHalving",
.sampler = NULL,

.optimize = function(inst) {
archive = inst$archive
r_min = private$.r_min
s_max = private$.s_max
eta = self$param_set$values$eta
budget_id = inst$search_space$ids(tags = "budget")
direction = inst$archive$codomain$direction
if (inst$archive$codomain$length == 1) {
lg$debug("Using fast algorithm for single-crit with ranks")
optimize_asha_single(inst, self, private)
} else {
lg$debug("Using default algorithm for multi-crit with hypervolume contribution")
optimize_asha_multi(inst, self, private)
}
}
)
)

mlr_optimizers$add("async_successive_halving", OptimizerAsyncSuccessiveHalving)

optimize_asha_multi = function(inst, self, private) {
archive = inst$archive
r_min = private$.r_min
s_max = private$.s_max
eta = self$param_set$values$eta
budget_id = inst$search_space$ids(tags = "budget")
direction = inst$archive$codomain$direction

while (!inst$is_terminated) {
# sample new point xs
xdt = private$.sampler$sample(1)$data
xs = transpose_list(xdt)[[1]]

# add unique id across stages, stage number, and budget
asha_id = UUIDgenerate()
xs = c(xs, list(asha_id = asha_id, stage = 1))
xs[[budget_id]] = private$.r_min

# evaluate
get_private(inst)$.eval_point(xs)

# s_max is 0 if r_min == r_max
if (s_max > 0) {
# iterate stages
for (s in seq(s_max)) {
lg$debug("Fetching results from other workers")

# fetch finished points of current stage
data_stage = archive$finished_data[list(s), , on = "stage"]

# how many configurations can be promoted to the next stage
# at least one configuration must be promotable
n_promotable = max(floor(nrow(data_stage) / eta), 1)

data.table::setDTthreads(1)
lg$debug("%i promotable configurations in stage %i", n_promotable, s)

while (!inst$is_terminated) {
# sample new point xs
xdt = private$.sampler$sample(1)$data
xs = transpose_list(xdt)[[1]]
# get the n best configurations of the current stage
candidates = private$.top_n(data_stage, archive$cols_y, n_promotable, direction)

# add unique id across stages, stage number, and budget
asha_id = UUIDgenerate()
xs = c(xs, list(asha_id = asha_id, stage = 1))
xs[[budget_id]] = private$.r_min
# if xs is not among the best configurations of the current stage draw a new random configuration
if (asha_id %nin% candidates$asha_id) {
lg$debug("Configuration %s is not promotable to stage %i", asha_id, s + 1)
break
}

lg$debug("Configuration %s is promotable to stage %i", asha_id, s + 1)

# increase budget of xs
rs = r_min * eta^s
if (inst$search_space$class[[budget_id]] == "ParamInt") rs = round(rs)
xs[[budget_id]] = rs
xs$stage = s + 1

# evaluate
get_private(inst)$.eval_point(xs)
}
}
}
}

optimize_asha_single = function(inst, self, private) {
archive = inst$archive
r_min = private$.r_min
s_max = private$.s_max
eta = self$param_set$values$eta
budget_id = inst$search_space$ids(tags = "budget")
direction = inst$archive$codomain$direction
r = inst$rush$connector
network_id = inst$rush$network_id

# s_max is 0 if r_min == r_max
if (s_max > 0) {
# iterate stages
for (s in seq(s_max)) {
lg$debug("Fetching results from other workers")
while (!inst$is_terminated) {
# sample new point xs
xdt = private$.sampler$sample(1)$data
xs = transpose_list(xdt)[[1]]

# fetch finished points of current stage
data_stage = archive$finished_data[list(s), , on = "stage"]
# add unique id across stages, stage number, and budget
asha_id = UUIDgenerate()
xs = c(xs, list(asha_id = asha_id, stage = 1))
xs[[budget_id]] = private$.r_min

# how many configurations can be promoted to the next stage
# at least one configuration must be promotable
n_promotable = max(floor(nrow(data_stage) / eta), 1)
# evaluate
ys = get_private(inst)$.eval_point(xs)

lg$debug("%i promotable configurations in stage %i", n_promotable, s)
# add result to leaderboard of the first stage
r$ZADD(sprintf("%s:stage_1", network_id), ys[[1]], asha_id)

# get the n best configurations of the current stage
candidates = private$.top_n(data_stage, archive$cols_y, n_promotable, direction)
# s_max is 0 if r_min == r_max
if (s_max > 0) {
# iterate stages
for (s in seq(s_max)) {
lg$debug("Fetching results")

# if xs is not among the best configurations of the current stage draw a new random configuration
if (asha_id %nin% candidates$asha_id) {
lg$debug("Configuration %s is not promotable to stage %i", asha_id, s + 1)
break
}
# number of configurations in the current stage
n_stage = r$ZCARD(sprintf("%s:stage_%i", network_id, s))

lg$debug("Configuration %s is promotable to stage %i", asha_id, s + 1)
# how many configurations can be promoted to the next stage
# at least one configuration must be promotable
n_promotable = max(floor(n_stage / eta), 1)

# increase budget of xs
rs = r_min * eta^s
if (inst$search_space$class[[budget_id]] == "ParamInt") rs = round(rs)
xs[[budget_id]] = rs
xs$stage = s + 1
lg$debug("%i promotable configurations in stage %i", n_promotable, s)

# evaluate
get_private(inst)$.eval_point(xs)
}
# get rank of xs in the current stage
# ranks are 0-based and in ascending order
rank = r$ZRANK(sprintf("%s:stage_%i", network_id, s), asha_id)

if ((direction == 1 && rank >= n_promotable) || (direction == -1 && rank < n_stage - n_promotable)) {
lg$debug("Configuration %s with rank %i of %i is not promotable to stage %i", asha_id, rank, n_stage, s + 1)
break
}

lg$debug("Configuration %s with rank %i of %i is promotable to stage %i", asha_id, rank, n_stage, s + 1, rank)

# increase budget of xs
rs = r_min * eta^s
if (inst$search_space$class[[budget_id]] == "ParamInt") rs = round(rs)
xs[[budget_id]] = rs
xs$stage = s + 1

# evaluate
ys = get_private(inst)$.eval_point(xs)

# add result to leaderboard of the next stage
r$ZADD(sprintf("%s:stage_%i", network_id, s + 1), ys[[1]], asha_id)
}
}
)
)
}

mlr_optimizers$add("async_successive_halving", OptimizerAsyncSuccessiveHalving)
# remove leaderboards
walk(seq(s_max + 1), function(s) {
r$DEL(sprintf("%s:stage_%i", network_id, s))
})
}
Loading

0 comments on commit 86f5085

Please sign in to comment.