Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Learner): support marshal property #993

Merged
merged 49 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
8fb3818
feat(Learner): support bundling property
sebffischer Jan 25, 2024
851e64c
fix(bundle): always (un)bundle for callr encapsulation
sebffischer Jan 25, 2024
1bcc0c7
fix: bundle cannot be present twice in properties
sebffischer Jan 26, 2024
bfb568f
typo
sebffischer Jan 30, 2024
606a5a9
refactor bundling
sebffischer Jan 31, 2024
ea651e2
bundle property must be manually set
sebffischer Jan 31, 2024
7d767de
fix tests
sebffischer Jan 31, 2024
291d450
better docs
sebffischer Jan 31, 2024
ff31f66
fix one more test
sebffischer Jan 31, 2024
2152b75
really fix test
sebffischer Jan 31, 2024
9bd2efe
public methods
sebffischer Jan 31, 2024
cfbab13
refactor
sebffischer Feb 1, 2024
f98424b
Update R/Measure.R
sebffischer Feb 1, 2024
50919d2
Update man-roxygen/param_learner_properties.R
sebffischer Feb 14, 2024
77af5ee
better marshal behavior
sebffischer Feb 20, 2024
c9c9c6a
Update R/Measure.R
sebffischer Feb 21, 2024
86e9163
Update R/Measure.R
sebffischer Feb 21, 2024
e2b1b54
docs
sebffischer Feb 21, 2024
c73f892
better approach
sebffischer Feb 22, 2024
e0c53ea
docs
sebffischer Feb 22, 2024
ffffc30
add clone argument and optimize worker
sebffischer Feb 22, 2024
432971f
optimization
sebffischer Feb 22, 2024
f9b33ea
add marshal property to regr.debug and remove lily
sebffischer Feb 26, 2024
d6ceb1c
inplace
sebffischer Mar 4, 2024
cdf603f
some more fixes
sebffischer Mar 5, 2024
ea2d75c
rename
sebffischer Apr 9, 2024
afbb6b5
...
sebffischer Apr 9, 2024
26c6c81
marshal is property of classif.debug
sebffischer Apr 9, 2024
2f5e685
fix printer and autotest
sebffischer Apr 9, 2024
f8943e7
Merge branch 'main' into bundle
sebffischer Apr 9, 2024
51e5c5f
...
sebffischer Apr 9, 2024
ad8cfb5
typo
sebffischer Apr 9, 2024
9581425
docs
sebffischer Apr 9, 2024
72733a9
refactor
sebffischer Apr 10, 2024
dfc345e
inplace marshal for ResultData
sebffischer Apr 10, 2024
c979ba8
fix class of marshaled classif debug
sebffischer Apr 17, 2024
7f637c8
add class to learner state for marshaling
sebffischer Apr 17, 2024
688f34c
...
sebffischer Apr 17, 2024
3a89b38
Merge branch 'main' into bundle
sebffischer Apr 17, 2024
626a3df
...
sebffischer Apr 17, 2024
5ada07c
typo
sebffischer Apr 17, 2024
3d64b1a
cleanup marshaling
sebffischer Apr 22, 2024
af67642
more cleanup
sebffischer Apr 22, 2024
0e82181
add test
sebffischer Apr 22, 2024
6e87384
more cleanup
sebffischer Apr 22, 2024
ef59a0b
better docs
sebffischer Apr 22, 2024
3707f8c
fix tests for at and glrn
sebffischer Apr 22, 2024
cbf2700
skip some tests until new versions are released
sebffischer Apr 22, 2024
c497a7d
fix test helpers
sebffischer Apr 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
public methods
  • Loading branch information
sebffischer committed Jan 31, 2024
commit 9bd2efe00f7585feada62f91ec67c376e8b00cdc
4 changes: 2 additions & 2 deletions R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ ResultData = R6Class("ResultData",
phash = self$data$fact[i, "learner_phash"][[1L]]
learner = self$data$learners[phash, "learner", on = "learner_phash"][[1L]][[1L]]
if (!is.null(state$model) && isFALSE(state$bundled)) {
state$model = get_private(learner)$.bundle(state$model)
state$model = learner$bundle_model(state$model)
state$bundled = TRUE
}
state
Expand All @@ -266,7 +266,7 @@ ResultData = R6Class("ResultData",
phash = self$data$fact[i, "learner_phash"][[1L]]
learner = self$data$learners[phash, "learner", on = "learner_phash"][[1L]][[1L]]
if (!is.null(state$model) && isTRUE(state$bundled)) {
state$model = get_private(learner)$.unbundle(state$model)
state$model = learner$unbundle_model(state$model)
state$bundled = FALSE
}
state
Expand Down
10 changes: 5 additions & 5 deletions R/bundle.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
#' In order to implement bundling for a [`Learner`], you need to add:
#' * the public methods `$bundle()` and `$unbundle()`, where you call `learner_bundle(self)` and
#' `learner_unbundle(self)` respectively.
#' * the active binding `$bundled`, where you simply call `learner_bundled(self)`.
#' * the private method `$.bundle(model)`, which takes in a [`Learner`]'s model and returns it in bundled form,
#' * the public method `$bundle_model(model)`, which takes in a [`Learner`]'s model and returns it in bundled form,
#' without modifying the learner's state. Must not depend on the learner's state.
#' * the private method `$.unbundle(model)`, which takes in a [`Learner`]'s bundled model and returns it in
#' * the public method `$unbundle_model(model)`, which takes in a [`Learner`]'s bundled model and returns it in
#' unbundled form. Must not depend on the learner's state.
#' * the active binding `$bundled`, where you simply call `learner_bundled(self)`.
#' * add the property `bundle` to the learner's properties.
#'
#' To test the bundling implementation, you can use the internal test helper `expect_bundleable()`.
Expand All @@ -47,7 +47,7 @@ learner_unbundle = function(learner) {
if (isFALSE(learner$bundled)) {
warningf("Learner '%s' has not been bundled, skipping.", learner$id)
} else if (isTRUE(learner$bundled)) {
learner$model = get_private(learner)$.unbundle(learner$model)
learner$model = learner$unbundle_model(learner$model)
learner$state$bundled = FALSE
}
invisible(learner)
Expand All @@ -62,7 +62,7 @@ learner_bundle = function(learner) {
if (isTRUE(learner$bundled)) {
warningf("Learner '%s' has already been bundled, skipping.", learner$id)
} else if ("bundle" %in% learner$properties) {
learner$model = get_private(learner)$.bundle(learner$model)
learner$model = learner$bundle_model(learner$model)
learner$state$bundled = TRUE
}
invisible(learner)
Expand Down
2 changes: 1 addition & 1 deletion R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
}

if ("bundle" %in% learner$properties && identical(learner$encapsulate[["train"]], "callr")) {
model = get_private(learner)$.bundle(model)
model = learner$bundle_model(model)
}

model
Expand Down
7 changes: 2 additions & 5 deletions inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,12 @@ expect_bundleable = function(learner, task) {
has_public = function(learner, x) {
exists(x, learner, inherits = FALSE)
}
has_private = function(learner, x) {
exists(x, mlr3misc::get_private(learner), inherits = FALSE)
}

expect_true(has_public(learner, "bundle") && test_function(learner$bundle, nargs = 0))
expect_true(has_public(learner, "unbundle") && test_function(learner$unbundle, nargs = 0))
expect_true(has_public(learner, "bundle"))
expect_true(has_private(learner, ".bundle") && test_function(mlr3misc::get_private(learner)$.bundle, nargs = 1, args = "model"))
expect_true(has_private(learner, ".unbundle") && test_function(mlr3misc::get_private(learner)$.unbundle, nargs = 1, args = "model"))
expect_true(has_public(learner, "bundle_model") && test_function(learner$bundle_model, nargs = 1, args = "model"))
expect_true(has_public(learner, "unbundle_model") && test_function(learner$unbundle_model, nargs = 1, args = "model"))

expect_false(learner$bundled)

Expand Down
6 changes: 3 additions & 3 deletions man/bundling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions tests/testthat/test_benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,12 @@ test_that("bundling", {
},
unbundle = function() {
learner_unbundle(self)
},
bundle_model = function(model) {
structure(list(model), class = private$.class)
},
unbundle_model = function(model) {
model[[1L]]
}
),
active = list(
Expand All @@ -501,12 +507,6 @@ test_that("bundling", {
}
),
private = list(
.bundle = function(model) {
structure(list(model), class = private$.class)
},
.unbundle = function(model) {
model[[1L]]
},
.class = NULL
)
)
Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/test_bundle.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ test_that("bundleable learner behaves as expected", {
},
unbundle = function() {
learner_unbundle(self)
}
),
private = list(
.bundle = function(model) {
},
bundle_model = function(model) {
private$.tmp_model = model
"bundle"
},
.unbundle = function(model) {
unbundle_model = function(model) {
model = private$.tmp_model
private$.tmp_model = NULL
private$.counter = private$.counter + 1
model
},
}
),
private = list(
.tmp_model = NULL,
.counter = 0
),
Expand Down
14 changes: 7 additions & 7 deletions tests/testthat/test_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,18 @@ test_that("bundling", {
self$properties = c("bundle", self$properties)
},
bundle = function() learner_bundle(self),
unbundle = function() learner_unbundle(self)
unbundle = function() learner_unbundle(self),
bundle_model = function(model) {
structure(list(model), class = "bundled")
},
unbundle_model = function(model) {
model[[1L]]
}
),
active = list(
bundled = function() learner_bundled(self)
),
private = list(
.bundle = function(model) {
structure(list(model), class = "bundled")
},
.unbundle = function(model) {
model[[1L]]
},
.tmp_model = NULL
)
)
Expand Down