Commit 1e00462

be-marc committed Dec 4, 2024
1 parent 8f7f308 commit 1e00462
Showing 5 changed files with 38 additions and 38 deletions.
15 changes: 11 additions & 4 deletions R/CallbackEvaluation.R
@@ -34,15 +34,22 @@ CallbackEvaluation = R6Class("CallbackEvaluation",
#'
#' @description
#' Function to create a [CallbackEvaluation].
#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()].
#'
#' Evaluation callbacks are called at different stages of the resampling process.
#' The stages are prefixed with `on_*`.
#'
#' ```
#' Start Evaluation on Worker
#' Start Resampling Iteration on Worker
#' - on_evaluation_begin
#' - on_evaluation_before_train
#' - on_evaluation_before_predict
#' - on_evaluation_end
#' End Evaluation on Worker
#' End Resampling Iteration on Worker
#' ```
#'
#' See also the section on parameters for more information on the stages.
#' An evaluation callback works with [ContextEvaluation].
#
#' @details
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
@@ -84,15 +91,15 @@ callback_evaluation = function(
on_evaluation_begin,
on_evaluation_before_train,
on_evaluation_before_predict,
on_evaluation_end ),
on_evaluation_end),
c(
"on_evaluation_begin",
"on_evaluation_before_train",
"on_evaluation_before_predict",
"on_evaluation_end"
)), is.null)

walk(stages, function(stage) assert_function(stage, args = c("callback", "context")))
stages = map(stages, function(stage) crate(assert_function(stage, args = c("callback", "context"))))
callback = CallbackEvaluation$new(id, label, man)
iwalk(stages, function(stage, name) callback[[name]] = stage)
callback
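For orientation, here is a minimal sketch of how a callback might be created with `callback_evaluation()` after this change. It assumes the constructor takes an id, a label, and the four `on_*` stage functions documented above; the id, label, and stage bodies are hypothetical, and only fields documented on `ContextEvaluation` (`iteration`, `task`, `pdatas`) are read.

```r
library(mlr3)

# Hypothetical logging callback; id, label, and messages are illustrative only.
logging_callback = callback_evaluation("evaluation.logging",
  label = "Log resampling iterations",
  on_evaluation_begin = function(callback, context) {
    # runs once per resampling iteration on the worker, before training
    message(sprintf("starting iteration %i on task '%s'",
      context$iteration, context$task$id))
  },
  on_evaluation_end = function(callback, context) {
    # context$pdatas (the prediction data) is only populated at this stage
    message(sprintf("finished iteration %i with %i prediction set(s)",
      context$iteration, length(context$pdatas)))
  }
)
```

Note also the switch from `walk()` to `map()` with `crate()` when registering the stages: the asserted stage functions are crated, presumably so they can be serialized to parallel workers without dragging their defining environments along.
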
26 changes: 3 additions & 23 deletions R/ContextEvaluation.R
@@ -25,30 +25,10 @@ ContextEvaluation = R6Class("ContextEvaluation",
#' The resampling is unchanged during the evaluation.
resampling = NULL,

#' @field param_values `list()`\cr
#' The parameter values to be used.
#' Is usually only set while tuning.
param_values = NULL,

#' @field iteration (`integer()`)\cr
#' The current iteration.
iteration = NULL,

#' @field sets (`list()`)\cr
#' The train and test set.
#' The sets are available on stage `on_evaluation_before_train`.
sets = NULL,

#' @field test_set (`integer()`)\cr
#' Validation test set.
#' The set is only available when using internal validation.
test_set = NULL,

#' @field predict_sets (`list()`)\cr
#' The prediction sets stored in `learner$predict_sets`.
#' The sets are available on stage `on_evaluation_before_predict`.
predict_sets = NULL,

#' @field pdatas (List of [PredictionData])\cr
#' The prediction data.
#' The data is available on stage `on_evaluation_end`.
@@ -68,9 +48,9 @@ ContextEvaluation = R6Class("ContextEvaluation",
#' The learner to be evaluated.
#' @param resampling ([Resampling])\cr
#' The resampling strategy to be used.
#' @param param_values (`list()`)\cr
#' The parameter values to be used.
initialize = function(task, learner, resampling, param_values, iteration) {
#' @param iteration (`integer()`)\cr
#' The current iteration.
initialize = function(task, learner, resampling, iteration) {
# no assertions to avoid overhead
self$task = task
self$learner = learner
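The net effect of this file's change is that the `param_values` field and constructor argument are dropped, and (as the worker.R hunks below show) the train/test sets and predict sets move into worker-local variables instead of being stored on the context. A stage function that still needs the row ids would presumably have to derive them from the `resampling` and `iteration` fields that remain on the context, roughly like this (hypothetical stage body):

```r
# Hypothetical on_evaluation_before_train stage: reconstructs the train/test
# row ids from the fields that remain on ContextEvaluation.
on_before_train = function(callback, context) {
  train_ids = context$resampling$train_set(context$iteration)
  test_ids = context$resampling$test_set(context$iteration)
  message(sprintf("training on %i rows, testing on %i rows",
    length(train_ids), length(test_ids)))
}
```
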
16 changes: 8 additions & 8 deletions R/worker.R
@@ -265,7 +265,7 @@ workhorse = function(
unmarshal = TRUE,
callbacks = NULL
) {
ctx = ContextEvaluation$new(task, learner, resampling, param_values, iteration)
ctx = ContextEvaluation$new(task, learner, resampling, iteration)

call_back("on_evaluation_begin", callbacks, ctx)

@@ -308,7 +308,7 @@ workhorse = function(
lg$info("%s learner '%s' on task '%s' (iter %i/%i)",
if (mode == "train") "Applying" else "Hotstarting", learner$id, task$id, iteration, resampling$iters)

ctx$sets = list(
sets = list(
train = resampling$train_set(iteration),
test = resampling$test_set(iteration)
)
@@ -323,11 +323,11 @@

validate = get0("validate", learner)

ctx$test_set = if (identical(validate, "test")) ctx$sets$test
ctx$test_set = if (identical(validate, "test")) sets$test

call_back("on_evaluation_before_train", callbacks, ctx)

train_result = learner_train(learner, task, ctx$sets[["train"]], ctx$test_set, mode = mode)
train_result = learner_train(learner, task, sets[["train"]], ctx$test_set, mode = mode)
ctx$learner = learner = train_result$learner

# process the model so it can be used for prediction (e.g. marshal for callr prediction), but also
@@ -338,20 +338,20 @@
)

# predict for each set
ctx$predict_sets = learner$predict_sets
predict_sets = learner$predict_sets

# creates the tasks and row_ids for all selected predict sets
pred_data = prediction_tasks_and_sets(task, train_result, validate, ctx$sets, ctx$predict_sets)
pred_data = prediction_tasks_and_sets(task, train_result, validate, sets, predict_sets)

call_back("on_evaluation_before_predict", callbacks, ctx)

pdatas = Map(function(set, row_ids, task) {
lg$debug("Creating Prediction for predict set '%s'", set)

learner_predict(learner, task, row_ids)
}, set = ctx$predict_sets, row_ids = pred_data$sets, task = pred_data$tasks)
}, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks)

if (!length(ctx$predict_sets)) {
if (!length(predict_sets)) {
learner$state$predict_time = 0L
}
ctx$pdatas = discard(pdatas, is.null)
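Putting the pieces together, usage could look roughly as follows. This assumes `resample()` exposes a `callbacks` argument that is forwarded to the `workhorse()` shown above (that interface is not part of this diff) and reuses the hypothetical `logging_callback` from the first sketch.

```r
library(mlr3)

task = tsk("penguins")
learner = lrn("classif.rpart")
resampling = rsmp("cv", folds = 3)

# Assumed interface: resample() forwards `callbacks` to workhorse();
# logging_callback is the hypothetical callback from the sketch above.
rr = resample(task, learner, resampling, callbacks = list(logging_callback))
rr$aggregate(msr("classif.ce"))
```
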
8 changes: 7 additions & 1 deletion man/ContextEvaluation.Rd

Some generated files are not rendered by default.

11 changes: 9 additions & 2 deletions man/callback_evaluation.Rd

Some generated files are not rendered by default.
