Skip to content

Commit

Permalink
using new paradox syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Jan 14, 2024
1 parent f746630 commit f2e0a1c
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 29 deletions.
8 changes: 4 additions & 4 deletions R/AcqFunctionEHVIGH.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ AcqFunctionEHVIGH = R6Class("AcqFunctionEHVIGH",
assert_r6(surrogate, "SurrogateLearnerCollection", null.ok = TRUE)
assert_int(k, lower = 2L)

constants = ParamSet$new(list(
ParamInt$new("k", lower = 2L, default = 15L),
ParamDbl$new("r", lower = 0, upper = 1, default = 0.2)
))
constants = ps(
k = p_int(lower = 2L, default = 15L),
r = p_dbl(lower = 0, upper = 1, default = 0.2)
)
constants$values$k = k
constants$values$r = r

Expand Down
8 changes: 4 additions & 4 deletions R/AcqFunctionSmsEgo.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ AcqFunctionSmsEgo = R6Class("AcqFunctionSmsEgo",
assert_number(lambda, lower = 1, finite = TRUE)
assert_number(epsilon, lower = 0, finite = TRUE, null.ok = TRUE)

constants = ParamSet$new(list(
ParamDbl$new("lambda", lower = 0, default = 1),
ParamDbl$new("epsilon", lower = 0, default = NULL, special_vals = list(NULL)) # for NULL, it will be calculated dynamically
))
constants = ps(
lambda = p_dbl(lower = 0, default = 1),
epsilon = p_dbl(lower = 0, default = NULL, special_vals = list(NULL)) # for NULL, it will be calculated dynamically
)
constants$values$lambda = lambda
constants$values$epsilon = epsilon

Expand Down
14 changes: 7 additions & 7 deletions R/AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ AcqOptimizer = R6Class("AcqOptimizer",
self$optimizer = assert_r6(optimizer, "Optimizer")
self$terminator = assert_r6(terminator, "Terminator")
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
ps = ParamSet$new(list(
ParamInt$new("n_candidates", lower = 1, default = 1L),
ParamFct$new("logging_level", levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
ParamLgl$new("warmstart", default = FALSE),
ParamInt$new("warmstart_size", lower = 1L, special_vals = list("all")),
ParamLgl$new("skip_already_evaluated", default = TRUE),
ParamLgl$new("catch_errors", default = TRUE))
ps = ps(
n_candidates = p_int(lower = 1, default = 1L),
logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
warmstart = p_lgl(default = FALSE),
warmstart_size = p_int(lower = 1L, special_vals = list("all")),
skip_already_evaluated = p_lgl(default = TRUE),
catch_errors = p_lgl(default = TRUE))
)
ps$values = list(n_candidates = 1, logging_level = "warn", warmstart = FALSE, skip_already_evaluated = TRUE, catch_errors = TRUE)
ps$add_dep("warmstart_size", on = "warmstart", cond = CondEqual$new(TRUE))
Expand Down
10 changes: 5 additions & 5 deletions R/SurrogateLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ SurrogateLearner = R6Class("SurrogateLearner",
assert_character(cols_x, min.len = 1L, null.ok = TRUE)
assert_string(col_y, null.ok = TRUE)

ps = ParamSet$new(list(
ParamLgl$new("assert_insample_perf"),
ParamUty$new("perf_measure", custom_check = function(x) check_r6(x, classes = "MeasureRegr")), # FIXME: actually want check_measure
ParamDbl$new("perf_threshold", lower = -Inf, upper = Inf),
ParamLgl$new("catch_errors"))
ps = ps(
assert_insample_perf = p_lgl(),
perf_measure = p_uty(custom_check = function(x) check_r6(x, classes = "MeasureRegr")), # FIXME: actually want check_measure
perf_threshold = p_dbl(lower = -Inf, upper = Inf),
catch_errors = p_lgl())
)
ps$values = list(assert_insample_perf = FALSE, catch_errors = TRUE)
ps$add_dep("perf_measure", on = "assert_insample_perf", cond = CondEqual$new(TRUE))
Expand Down
10 changes: 5 additions & 5 deletions R/SurrogateLearnerCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
assert_character(cols_x, min.len = 1L, null.ok = TRUE)
assert_character(cols_y, len = length(learners), null.ok = TRUE)

ps = ParamSet$new(list(
ParamLgl$new("assert_insample_perf"),
ParamUty$new("perf_measures", custom_check = function(x) check_list(x, types = "MeasureRegr", any.missing = FALSE, len = length(learners))), # FIXME: actually want check_measures
ParamUty$new("perf_thresholds", custom_check = function(x) check_double(x, lower = -Inf, upper = Inf, any.missing = FALSE, len = length(learners))),
ParamLgl$new("catch_errors"))
ps = ps(
assert_insample_perf = p_lgl(),
perf_measures = p_uty(custom_check = function(x) check_list(x, types = "MeasureRegr", any.missing = FALSE, len = length(learners))), # FIXME: actually want check_measures
perf_thresholds = p_uty(custom_check = function(x) check_double(x, lower = -Inf, upper = Inf, any.missing = FALSE, len = length(learners))),
catch_errors = p_lgl())
)
ps$values = list(assert_insample_perf = FALSE, catch_errors = TRUE)
ps$add_dep("perf_measures", on = "assert_insample_perf", cond = CondEqual$new(TRUE))
Expand Down
5 changes: 1 addition & 4 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ generate_acq_codomain = function(surrogate, id, direction = "same") {
} else {
tags = direction
}
codomain = ParamSet$new(list(
ParamDbl$new(id, tags = tags)
))
codomain
do.call(ps, structure(list(p_dbl(tags = tags)), names = id))
}

generate_acq_domain = function(surrogate) {
Expand Down

0 comments on commit f2e0a1c

Please sign in to comment.