From 626868732c27d62a33dc0821f26fac438986c9cd Mon Sep 17 00:00:00 2001 From: Lennart Schneider Date: Mon, 12 Aug 2024 16:01:40 +0200 Subject: [PATCH 1/4] feat: support callbacks in AcqOptimizer --- DESCRIPTION | 2 +- R/AcqOptimizer.R | 13 +++++++++---- man/AcqOptimizer.Rd | 10 +++++++--- tests/testthat/test_AcqOptimizer.R | 29 +++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 2a58168e..5cf62f96 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -72,7 +72,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: false NeedsCompilation: yes Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Collate: 'mlr_acqfunctions.R' 'AcqFunction.R' diff --git a/R/AcqOptimizer.R b/R/AcqOptimizer.R index 68d6cd84..9cc848c1 100644 --- a/R/AcqOptimizer.R +++ b/R/AcqOptimizer.R @@ -71,7 +71,7 @@ #' #' acq_function = acqf("ei", surrogate = surrogate) #' -#' acq_function$surrogate$update() +#' acq_function$surrogate$update( #' acq_function$update() #' #' acq_optimizer = acqo( @@ -84,7 +84,7 @@ AcqOptimizer = R6Class("AcqOptimizer", public = list( - #' @field optimizer ([bbotk::Optimizer]). + #' @field optimizer ([bbotk::OptimizerBatch]). optimizer = NULL, #' @field terminator ([bbotk::Terminator]). @@ -93,16 +93,21 @@ AcqOptimizer = R6Class("AcqOptimizer", #' @field acq_function ([AcqFunction]). acq_function = NULL, + #' @field callbacks (`NULL` | list of [bbotk::CallbackBatch]). + callbacks = NULL, + #' @description #' Creates a new instance of this [R6][R6::R6Class] class. #' #' @param optimizer ([bbotk::Optimizer]). #' @param terminator ([bbotk::Terminator]). #' @param acq_function (`NULL` | [AcqFunction]). - initialize = function(optimizer, terminator, acq_function = NULL) { + #' @param callbacks (`NULL` | list of [bbotk::CallbackBatch]) + initialize = function(optimizer, terminator, acq_function = NULL, callbacks = NULL) { self$optimizer = assert_r6(optimizer, "Optimizer") self$terminator = assert_r6(terminator, "Terminator") self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE) + self$callbacks = assert_callbacks(as_callbacks(callbacks)) ps = ps( n_candidates = p_int(lower = 1, default = 1L), logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"), @@ -146,7 +151,7 @@ AcqOptimizer = R6Class("AcqOptimizer", logger$set_threshold(self$param_set$values$logging_level) on.exit(logger$set_threshold(old_threshold)) - instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE) + instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks) # warmstart if (self$param_set$values$warmstart) { diff --git a/man/AcqOptimizer.Rd b/man/AcqOptimizer.Rd index f0f8d9b3..76894502 100644 --- a/man/AcqOptimizer.Rd +++ b/man/AcqOptimizer.Rd @@ -75,7 +75,7 @@ if (requireNamespace("mlr3learners") & acq_function = acqf("ei", surrogate = surrogate) - acq_function$surrogate$update() + acq_function$surrogate$update( acq_function$update() acq_optimizer = acqo( @@ -89,11 +89,13 @@ if (requireNamespace("mlr3learners") & \section{Public fields}{ \if{html}{\out{
}} \describe{ -\item{\code{optimizer}}{(\link[bbotk:Optimizer]{bbotk::Optimizer}).} +\item{\code{optimizer}}{(\link[bbotk:OptimizerBatch]{bbotk::OptimizerBatch}).} \item{\code{terminator}}{(\link[bbotk:Terminator]{bbotk::Terminator}).} \item{\code{acq_function}}{(\link{AcqFunction}).} + +\item{\code{callbacks}}{(\code{NULL} | list of \link[bbotk:CallbackBatch]{bbotk::CallbackBatch}).} } \if{html}{\out{
}} } @@ -124,7 +126,7 @@ Set of hyperparameters.} \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AcqOptimizer$new(optimizer, terminator, acq_function = NULL)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AcqOptimizer$new(optimizer, terminator, acq_function = NULL, callbacks = NULL)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -135,6 +137,8 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \item{\code{terminator}}{(\link[bbotk:Terminator]{bbotk::Terminator}).} \item{\code{acq_function}}{(\code{NULL} | \link{AcqFunction}).} + +\item{\code{callbacks}}{(\code{NULL} | list of \link[bbotk:CallbackBatch]{bbotk::CallbackBatch})} } \if{html}{\out{}} } diff --git a/tests/testthat/test_AcqOptimizer.R b/tests/testthat/test_AcqOptimizer.R index 213963fe..69e2ca2f 100644 --- a/tests/testthat/test_AcqOptimizer.R +++ b/tests/testthat/test_AcqOptimizer.R @@ -120,3 +120,32 @@ test_that("AcqOptimizer deep clone", { expect_true(address(acqopt1$terminator) != address(acqopt2$terminator)) }) +test_that("AcqOptimizer callbacks", { + domain = ps(x = p_dbl(lower = 10, upper = 20, trafo = function(x) x - 15)) + objective = ObjectiveRFunDt$new( + fun = function(xdt) data.table(y = xdt$x ^ 2), + domain = domain, + codomain = ps(y = p_dbl(tags = "minimize")), + check_values = FALSE + ) + instance = MAKE_INST(objective = objective, search_space = domain, terminator = trm("evals", n_evals = 5L)) + design = MAKE_DESIGN(instance) + instance$eval_batch(design) + callback = callback_batch("mlr3mbo.acqopt_time", + on_optimization_begin = function(callback, context) { + callback$state$begin = Sys.time() + }, + on_optimization_end = function(callback, context) { + callback$state$end = Sys.time() + attr(callback$state$outer_instance, "acq_opt_runtime") = as.numeric(callback$state$end - callback$state$begin) + } + ) + callback$state$outer_instance = instance + acqfun = AcqFunctionEI$new(SurrogateLearner$new(REGR_FEATURELESS, archive = instance$archive)) + acqopt = AcqOptimizer$new(opt("random_search", batch_size = 10L), trm("evals", n_evals = 10L), acq_function = acqfun, callbacks = callback) + acqfun$surrogate$update() + acqfun$update() + res = acqopt$optimize() + expect_number(attr(instance, "acq_opt_runtime")) +}) + From cc523462199baf6481357e59dbdfd9b116e282ce Mon Sep 17 00:00:00 2001 From: Lennart Schneider Date: Mon, 12 Aug 2024 16:20:15 +0200 Subject: [PATCH 2/4] typo --- R/AcqOptimizer.R | 8 ++++---- man/AcqOptimizer.Rd | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/R/AcqOptimizer.R b/R/AcqOptimizer.R index 9cc848c1..f9f0f5d6 100644 --- a/R/AcqOptimizer.R +++ b/R/AcqOptimizer.R @@ -71,7 +71,7 @@ #' #' acq_function = acqf("ei", surrogate = surrogate) #' -#' acq_function$surrogate$update( +#' acq_function$surrogate$update() #' acq_function$update() #' #' acq_optimizer = acqo( @@ -84,7 +84,7 @@ AcqOptimizer = R6Class("AcqOptimizer", public = list( - #' @field optimizer ([bbotk::OptimizerBatch]). + #' @field optimizer ([bbotk::Optimizer]). optimizer = NULL, #' @field terminator ([bbotk::Terminator]). @@ -93,7 +93,7 @@ AcqOptimizer = R6Class("AcqOptimizer", #' @field acq_function ([AcqFunction]). acq_function = NULL, - #' @field callbacks (`NULL` | list of [bbotk::CallbackBatch]). + #' @field callbacks (`NULL` | list of [mlr3misc::Callback]). callbacks = NULL, #' @description @@ -102,7 +102,7 @@ AcqOptimizer = R6Class("AcqOptimizer", #' @param optimizer ([bbotk::Optimizer]). #' @param terminator ([bbotk::Terminator]). #' @param acq_function (`NULL` | [AcqFunction]). - #' @param callbacks (`NULL` | list of [bbotk::CallbackBatch]) + #' @param callbacks (`NULL` | list of [mlr3misc::Callback]) initialize = function(optimizer, terminator, acq_function = NULL, callbacks = NULL) { self$optimizer = assert_r6(optimizer, "Optimizer") self$terminator = assert_r6(terminator, "Terminator") diff --git a/man/AcqOptimizer.Rd b/man/AcqOptimizer.Rd index 76894502..b4c143b1 100644 --- a/man/AcqOptimizer.Rd +++ b/man/AcqOptimizer.Rd @@ -75,7 +75,7 @@ if (requireNamespace("mlr3learners") & acq_function = acqf("ei", surrogate = surrogate) - acq_function$surrogate$update( + acq_function$surrogate$update() acq_function$update() acq_optimizer = acqo( @@ -89,13 +89,13 @@ if (requireNamespace("mlr3learners") & \section{Public fields}{ \if{html}{\out{
}} \describe{ -\item{\code{optimizer}}{(\link[bbotk:OptimizerBatch]{bbotk::OptimizerBatch}).} +\item{\code{optimizer}}{(\link[bbotk:Optimizer]{bbotk::Optimizer}).} \item{\code{terminator}}{(\link[bbotk:Terminator]{bbotk::Terminator}).} \item{\code{acq_function}}{(\link{AcqFunction}).} -\item{\code{callbacks}}{(\code{NULL} | list of \link[bbotk:CallbackBatch]{bbotk::CallbackBatch}).} +\item{\code{callbacks}}{(\code{NULL} | list of \link[mlr3misc:Callback]{mlr3misc::Callback}).} } \if{html}{\out{
}} } @@ -138,7 +138,7 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \item{\code{acq_function}}{(\code{NULL} | \link{AcqFunction}).} -\item{\code{callbacks}}{(\code{NULL} | list of \link[bbotk:CallbackBatch]{bbotk::CallbackBatch})} +\item{\code{callbacks}}{(\code{NULL} | list of \link[mlr3misc:Callback]{mlr3misc::Callback})} } \if{html}{\out{}} } From fc3b4ab8875cf5d03a98c0728f3dfab4c67094a6 Mon Sep 17 00:00:00 2001 From: Lennart Schneider Date: Mon, 12 Aug 2024 17:18:00 +0200 Subject: [PATCH 3/4] acqo sugar --- R/sugar.R | 8 +++++--- man/acqo.Rd | 5 ++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/R/sugar.R b/R/sugar.R index ac9279c6..69c6a8b0 100644 --- a/R/sugar.R +++ b/R/sugar.R @@ -20,7 +20,7 @@ #' @param cols_y (`NULL` | `character()`)\cr #' Column id(s) in the [bbotk::Archive] that should be used as a target. #' If a list of [mlr3::LearnerRegr] is provided as the `learner` argument and `cols_y` is -#' specified as well, as many column names as learners must be provided. +#' specified as well, as many column names as learners must be provided. #' Can also be `NULL` in which case this is automatically inferred based on the archive. #' @param ... (named `list()`)\cr #' Named arguments passed to the constructor, to be set as parameters in the @@ -90,6 +90,8 @@ acqf = function(.key, ...) { #' @param acq_function (`NULL` | [AcqFunction])\cr #' [AcqFunction] that is to be used. #' Can also be `NULL`. +#' @param callbacks (`NULL` | list of [mlr3misc::Callback]) +#' Callbacks used during acquisition function optimization. #' @param ... (named `list()`)\cr #' Named arguments passed to the constructor, to be set as parameters in the #' [paradox::ParamSet]. @@ -101,9 +103,9 @@ acqf = function(.key, ...) { #' library(bbotk) #' acqo(opt("random_search"), trm("evals"), catch_errors = FALSE) #' @export -acqo = function(optimizer, terminator, acq_function = NULL, ...) { +acqo = function(optimizer, terminator, acq_function = NULL, callbacks = NULL, ...) { dots = list(...) - acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function) + acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function, callbacks = callbacks) acqopt$param_set$values = insert_named(acqopt$param_set$values, dots) acqopt } diff --git a/man/acqo.Rd b/man/acqo.Rd index 9455152e..43114a0c 100644 --- a/man/acqo.Rd +++ b/man/acqo.Rd @@ -4,7 +4,7 @@ \alias{acqo} \title{Syntactic Sugar Acquisition Function Optimizer Construction} \usage{ -acqo(optimizer, terminator, acq_function = NULL, ...) +acqo(optimizer, terminator, acq_function = NULL, callbacks = NULL, ...) } \arguments{ \item{optimizer}{(\link[bbotk:Optimizer]{bbotk::Optimizer})\cr @@ -17,6 +17,9 @@ acqo(optimizer, terminator, acq_function = NULL, ...) \link{AcqFunction} that is to be used. Can also be \code{NULL}.} +\item{callbacks}{(\code{NULL} | list of \link[mlr3misc:Callback]{mlr3misc::Callback}) +Callbacks used during acquisition function optimization.} + \item{...}{(named \code{list()})\cr Named arguments passed to the constructor, to be set as parameters in the \link[paradox:ParamSet]{paradox::ParamSet}.} From a5e6c0219abdf49a4ca7b3ec616c177c2446027f Mon Sep 17 00:00:00 2001 From: Lennart Schneider Date: Tue, 13 Aug 2024 17:46:42 +0200 Subject: [PATCH 4/4] remotes bbotk --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index 5cf62f96..ffc831b7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -66,6 +66,7 @@ Suggests: rpart, stringi, testthat (>= 3.0.0) +Remotes: mlr-org/bbotk ByteCompile: no Encoding: UTF-8 Config/testthat/edition: 3