Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support callbacks in AcqOptimizer #153

Merged
merged 4 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,14 @@ Suggests:
rpart,
stringi,
testthat (>= 3.0.0)
Remotes: mlr-org/bbotk
ByteCompile: no
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: yes
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Collate:
'mlr_acqfunctions.R'
'AcqFunction.R'
Expand Down
9 changes: 7 additions & 2 deletions R/AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,21 @@ AcqOptimizer = R6Class("AcqOptimizer",
#' @field acq_function ([AcqFunction]).
acq_function = NULL,

#' @field callbacks (`NULL` | list of [mlr3misc::Callback]).
callbacks = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param optimizer ([bbotk::Optimizer]).
#' @param terminator ([bbotk::Terminator]).
#' @param acq_function (`NULL` | [AcqFunction]).
initialize = function(optimizer, terminator, acq_function = NULL) {
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
initialize = function(optimizer, terminator, acq_function = NULL, callbacks = NULL) {
self$optimizer = assert_r6(optimizer, "Optimizer")
self$terminator = assert_r6(terminator, "Terminator")
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
self$callbacks = assert_callbacks(as_callbacks(callbacks))
ps = ps(
n_candidates = p_int(lower = 1, default = 1L),
logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
Expand Down Expand Up @@ -146,7 +151,7 @@ AcqOptimizer = R6Class("AcqOptimizer",
logger$set_threshold(self$param_set$values$logging_level)
on.exit(logger$set_threshold(old_threshold))

instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE)
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks)

# warmstart
if (self$param_set$values$warmstart) {
Expand Down
8 changes: 5 additions & 3 deletions R/sugar.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#' @param cols_y (`NULL` | `character()`)\cr
#' Column id(s) in the [bbotk::Archive] that should be used as a target.
#' If a list of [mlr3::LearnerRegr] is provided as the `learner` argument and `cols_y` is
#' specified as well, as many column names as learners must be provided.
#' specified as well, as many column names as learners must be provided.
#' Can also be `NULL` in which case this is automatically inferred based on the archive.
#' @param ... (named `list()`)\cr
#' Named arguments passed to the constructor, to be set as parameters in the
Expand Down Expand Up @@ -90,6 +90,8 @@ acqf = function(.key, ...) {
#' @param acq_function (`NULL` | [AcqFunction])\cr
#' [AcqFunction] that is to be used.
#' Can also be `NULL`.
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
#' Callbacks used during acquisition function optimization.
#' @param ... (named `list()`)\cr
#' Named arguments passed to the constructor, to be set as parameters in the
#' [paradox::ParamSet].
Expand All @@ -101,9 +103,9 @@ acqf = function(.key, ...) {
#' library(bbotk)
#' acqo(opt("random_search"), trm("evals"), catch_errors = FALSE)
#' @export
acqo = function(optimizer, terminator, acq_function = NULL, ...) {
acqo = function(optimizer, terminator, acq_function = NULL, callbacks = NULL, ...) {
dots = list(...)
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function)
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function, callbacks = callbacks)
acqopt$param_set$values = insert_named(acqopt$param_set$values, dots)
acqopt
}
Expand Down
6 changes: 5 additions & 1 deletion man/AcqOptimizer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/acqo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions tests/testthat/test_AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,32 @@ test_that("AcqOptimizer deep clone", {
expect_true(address(acqopt1$terminator) != address(acqopt2$terminator))
})

test_that("AcqOptimizer callbacks", {
domain = ps(x = p_dbl(lower = 10, upper = 20, trafo = function(x) x - 15))
objective = ObjectiveRFunDt$new(
fun = function(xdt) data.table(y = xdt$x ^ 2),
domain = domain,
codomain = ps(y = p_dbl(tags = "minimize")),
check_values = FALSE
)
instance = MAKE_INST(objective = objective, search_space = domain, terminator = trm("evals", n_evals = 5L))
design = MAKE_DESIGN(instance)
instance$eval_batch(design)
callback = callback_batch("mlr3mbo.acqopt_time",
on_optimization_begin = function(callback, context) {
callback$state$begin = Sys.time()
},
on_optimization_end = function(callback, context) {
callback$state$end = Sys.time()
attr(callback$state$outer_instance, "acq_opt_runtime") = as.numeric(callback$state$end - callback$state$begin)
}
)
callback$state$outer_instance = instance
acqfun = AcqFunctionEI$new(SurrogateLearner$new(REGR_FEATURELESS, archive = instance$archive))
acqopt = AcqOptimizer$new(opt("random_search", batch_size = 10L), trm("evals", n_evals = 10L), acq_function = acqfun, callbacks = callback)
acqfun$surrogate$update()
acqfun$update()
res = acqopt$optimize()
expect_number(attr(instance, "acq_opt_runtime"))
})

Loading