Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: offset column role in Task #1225

Merged
merged 34 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
032c4b4
add assert_scorable() to NAMESPACE
bblodfon Dec 4, 2024
b271502
add offset to reflections
bblodfon Dec 4, 2024
76ae35a
add offset col_role to Task
bblodfon Dec 4, 2024
0f1b465
update doc
bblodfon Dec 4, 2024
b43a1e4
add test for offset
bblodfon Dec 4, 2024
aaf3f38
add John as ctb
bblodfon Dec 4, 2024
dbc591d
update news
bblodfon Dec 4, 2024
adbbcbf
Merge branch 'main' into add_offset
be-marc Dec 20, 2024
dbe05cc
...
be-marc Jan 16, 2025
8e3e5d6
Merge branch 'main' into add_offset
be-marc Jan 16, 2025
ed58474
Merge branch 'tmp' into add_offset
be-marc Jan 16, 2025
d88e493
...
be-marc Jan 16, 2025
afce34c
...
be-marc Jan 16, 2025
a8e0469
...
be-marc Jan 16, 2025
339c4ab
...
be-marc Jan 16, 2025
fd729db
...
be-marc Jan 16, 2025
3943c9e
...
be-marc Jan 20, 2025
caae259
...
be-marc Jan 20, 2025
a24b6ac
add offset learner property
bblodfon Jan 22, 2025
0806838
add offset field + tests + doc
bblodfon Jan 22, 2025
e380fd4
add warning during training when task has offset but learner doesn't …
bblodfon Jan 22, 2025
8bd0ab7
add documentation for offset learner property
bblodfon Jan 31, 2025
e9d651e
update news
bblodfon Jan 31, 2025
54772ce
better doc
bblodfon Feb 7, 2025
3f9da4b
add row_id in the output data.table
bblodfon Feb 7, 2025
dac498f
update tests
bblodfon Feb 7, 2025
abbab81
refine doc
bblodfon Feb 7, 2025
0c1fc7d
refine doc on the offset field
bblodfon Feb 7, 2025
9c21f84
add offset task in autotest
bblodfon Feb 7, 2025
d29bc81
add offset task in autotest
bblodfon Feb 7, 2025
3c89e77
generate multiple offset columns for multiclass tasks
bblodfon Feb 10, 2025
9281637
Update R/Task.R
be-marc Feb 10, 2025
02b8aa6
Merge branch 'main' into add_offset
be-marc Feb 10, 2025
2dbb6d0
add namespace for set_names()
bblodfon Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ Authors@R:
comment = c(ORCID = "0000-0002-8115-0400")),
person("Sebastian", "Fischer", , "[email protected]", role = "aut",
comment = c(ORCID = "0000-0002-9609-3197")),
person("Lona", "Koers", , "[email protected]", role = "ctb")
person("Lona", "Koers", , "[email protected]", role = "ctb"),
person("John", "Zobolas", , "[email protected]", role = "ctb",
comment = c(ORCID = "0000-0002-3609-8674"))
)
Description: Efficient, object-oriented programming on the
building blocks of machine learning. Provides 'R6' objects for tasks,
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# mlr3 (development version)

* feat: add new `col_role` offset in `Task` and offset `Learner` property.
A warning is produced if a learner that doesn't support offsets is trained with a task that has an offset column.
* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions (#685)
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* Column names with UTF-8 characters are now allowed by default.
The option `mlr3.allow_utf8_names` is removed.
* BREAKING CHANGE: `Learner$predict_types` is read-only now.
Expand Down
53 changes: 48 additions & 5 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ Task = R6Class("Task",
}

# columns with these roles must be present in data
mandatory_roles = c("target", "feature", "weight", "group", "stratum", "order")
mandatory_roles = c("target", "feature", "weight", "group", "stratum", "order", "offset")
mandatory_cols = unlist(private$.col_roles[mandatory_roles], use.names = FALSE)
missing_cols = setdiff(mandatory_cols, data$colnames)
if (length(missing_cols)) {
Expand Down Expand Up @@ -896,6 +896,7 @@ Task = R6Class("Task",
#' * `"strata"`: The task is resampled using one or more stratification variables (role `"stratum"`).
#' * `"groups"`: The task comes with grouping/blocking information (role `"group"`).
#' * `"weights"`: The task comes with observation weights (role `"weight"`).
#' * `"offset"`: The task includes one or more offset columns specifying fixed adjustments for model training and possibly for prediction (role `"offset"`).
#' * `"ordered"`: The task has columns which define the row order (role `"order"`).
#'
#' Note that above listed properties are calculated from the `$col_roles` and may not be set explicitly.
Expand All @@ -907,6 +908,7 @@ Task = R6Class("Task",
if (length(col_roles$group)) "groups" else NULL,
if (length(col_roles$stratum)) "strata" else NULL,
if (length(col_roles$weight)) "weights" else NULL,
if (length(col_roles$offset)) "offset" else NULL,
if (length(col_roles$order)) "ordered" else NULL
)
} else {
Expand Down Expand Up @@ -951,6 +953,10 @@ Task = R6Class("Task",
#' Not more than a single column can be associated with this role.
#' * `"stratum"`: Stratification variables. Multiple discrete columns may have this role.
#' * `"weight"`: Observation weights. Not more than one numeric column may have this role.
#' * `"offset"`: Offset values specifying fixed adjustments for model training.
#' These values can be used to provide baseline predictions from an existing model for updating another model.
#' Some learners require an offset for each target class in a multiclass setting.
#' In this case, the offset columns must be named `"offset_{target_class_name}"`.
#'
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
#' `col_roles` is a named list whose elements are named by column role and each element is a `character()` vector of column names.
#' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots).
Expand Down Expand Up @@ -1084,6 +1090,23 @@ Task = R6Class("Task",
setnames(data, c("row_id", "weight"))[]
},

#' @field offset ([data.table::data.table()])\cr
#' Provides the offset column(s) if the task has a column designated with the role `"offset"`.
#'
#' For regression or binary classification tasks, this returns a single-column offset.
#' For multiclass tasks, it may return multiple offset columns, one for each target class.
#'
#' If there are no columns with the `"offset"` role, `NULL` is returned.
offset = function(rhs) {
assert_has_backend(self)
assert_ro_binding(rhs)
offset_cols = private$.col_roles$offset
if (length(offset_cols) == 0L) {
return(NULL)
}

self$backend$data(private$.row_roles$use, offset_cols)
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
},

#' @field labels (named `character()`)\cr
#' Retrieve `labels` (prettier formated names) from columns.
Expand Down Expand Up @@ -1250,6 +1273,17 @@ task_check_col_roles.Task = function(task, new_roles, ...) {
}
}

# check offset
if (length(new_roles[["offset"]]) && any(fget(task$col_info, new_roles[["offset"]], "type", key = "id") %nin% c("numeric", "integer"))) {
stopf("Offset column(s) %s must be a numeric or integer column", paste0("'", new_roles[["offset"]], "'", collapse = ","))
}

if (any(task$missings(cols = new_roles[["offset"]]) > 0)) {
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
missings = task$missings(cols = new_roles[["offset"]])
missings = names(missings[missings > 0])
stopf("Offset column(s) %s contain missing values", paste0("'", missings, "'", collapse = ","))
}

return(new_roles)
}

Expand All @@ -1266,16 +1300,25 @@ task_check_col_roles.TaskClassif = function(task, new_roles, ...) {
stopf("Target column(s) %s must be a factor or ordered factor", paste0("'", new_roles[["target"]], "'", collapse = ","))
}

if (length(new_roles[["offset"]]) > 1L && length(task$class_names) == 2L) {
stop("There may only be up to one column with role 'offset' for binary classification tasks")
}

if (length(new_roles[["offset"]]) > 1L) {
expected_names = paste0("offset_", task$class_names)
expect_subset(new_roles[["offset"]], expected_names, label = "col_roles")
}

NextMethod()
}

#' @rdname task_check_col_roles
#' @export
task_check_col_roles.TaskRegr = function(task, new_roles, ...) {

# check target
if (length(new_roles[["target"]]) > 1L) {
stopf("There may only be up to one column with role 'target'")
for (role in c("target", "offset")) {
if (length(new_roles[[role]]) > 1L) {
stopf("There may only be up to one column with role '%s'", role)
}
}

if (length(new_roles[["target"]]) && any(fget(task$col_info, new_roles[["target"]], "type", key = "id") %nin% c("numeric", "integer"))) {
Expand Down
5 changes: 5 additions & 0 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ assert_task_learner = function(task, learner, cols = NULL) {
}
}

if ("offset" %in% task$properties && "offset" %nin% learner$properties) {
warningf("Task '%s' has offset, but learner '%s' does not support this, so it will be ignored",
task$id, learner$id)
}

tmp = mlr_reflections$task_mandatory_properties[[task$task_type]]
if (length(tmp)) {
tmp = setdiff(intersect(task$properties, tmp), learner$properties)
Expand Down
8 changes: 4 additions & 4 deletions R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ local({
"use"
)

tmp = c("feature", "target", "name", "order", "stratum", "group", "weight")
tmp = c("feature", "target", "name", "order", "stratum", "group", "weight", "offset")
mlr_reflections$task_col_roles = list(
regr = tmp,
classif = tmp,
unsupervised = c("feature", "name", "order")
)

tmp = c("strata", "groups", "weights")
tmp = c("strata", "groups", "weights", "offset")
mlr_reflections$task_properties = list(
classif = c(tmp, "twoclass", "multiclass"),
regr = tmp,
Expand All @@ -114,11 +114,11 @@ local({

mlr_reflections$task_print_col_roles = list(
before = character(),
after = c("Order by" = "order", "Strata" = "stratum", "Groups" = "group", "Weights" = "weight")
after = c("Order by" = "order", "Strata" = "stratum", "Groups" = "group", "Weights" = "weight", "Offset" = "offset")
)

### Learner
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "hotstart_forward", "hotstart_backward", "validation", "internal_tuning", "marshal")
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "hotstart_forward", "hotstart_backward", "validation", "internal_tuning", "marshal", "offset")
mlr_reflections$learner_properties = list(
classif = c(tmp, "twoclass", "multiclass"),
regr = tmp
Expand Down
2 changes: 2 additions & 0 deletions man-roxygen/param_learner_properties.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
#' The following properties are currently standardized and understood by learners in \CRANpkg{mlr3}:
#' * `"missings"`: The learner can handle missing values in the data.
#' * `"weights"`: The learner supports observation weights.
#' * `"offset"`: The learner can incorporate offset values to adjust predictions.
#' * `"importance"`: The learner supports extraction of importance scores, i.e. comes with an `$importance()` extractor function (see section on optional extractors in [Learner]).
be-marc marked this conversation as resolved.
Show resolved Hide resolved
#' * `"selected_features"`: The learner supports extraction of the set of selected features, i.e. comes with a `$selected_features()` extractor function (see section on optional extractors in [Learner]).
#' * `"oob_error"`: The learner supports extraction of estimated out of bag error, i.e. comes with a `oob_error()` extractor function (see section on optional extractors in [Learner]).
#' * `"validation"`: The learner can use a validation task during training.
#' * `"internal_tuning"`: The learner is able to internally optimize hyperparameters (those are also tagged with `"internal_tuning"`).
#' * `"marshal"`: To save learners with this property, you need to call `$marshal()` first.
#' If a learner is in a marshaled state, you call first need to call `$unmarshal()` to use its model, e.g. for prediction.
#'
1 change: 1 addition & 0 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/LearnerClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/LearnerRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr3-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,18 @@ test_that("stratify works", {
})

test_that("groups/weights work", {
b = as_data_backend(data.table(x = runif(20), y = runif(20), w = runif(20), g = sample(letters[1:2], 20, replace = TRUE)))
b = as_data_backend(data.table(x = runif(20), y = runif(20), w = runif(20),
o = runif(20), g = sample(letters[1:2], 20, replace = TRUE)))
task = TaskRegr$new("test", b, target = "y")
task$set_row_roles(16:20, character())

expect_false("groups" %chin% task$properties)
expect_false("weights" %chin% task$properties)
expect_false("offset" %chin% task$properties)
expect_null(task$groups)
expect_null(task$weights)

# weight
task$col_roles$weight = "w"
expect_subset("weights", task$properties)
expect_data_table(task$weights, ncols = 2, nrows = 15)
Expand All @@ -265,6 +268,7 @@ test_that("groups/weights work", {
task$col_roles$weight = character()
expect_true("weights" %nin% task$properties)

# group
task$col_roles$group = "g"
expect_subset("groups", task$properties)
expect_data_table(task$groups, ncols = 2, nrows = 15)
Expand Down Expand Up @@ -726,3 +730,4 @@ test_that("warn when internal valid task has 0 obs", {
task = tsk("iris")
expect_warning({task$internal_valid_task = 151}, "has 0 observations")
})

50 changes: 50 additions & 0 deletions tests/testthat/test_TaskClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,53 @@ test_that("target is encoded as factor (#629)", {
dt$target = ordered(dt$target)
TaskClassif$new(id = "XX", backend = dt, target = "target")
})

test_that("offset column role works with binary tasks", {
task = tsk("pima")
expect_null(task$offset)

task$set_col_roles("age", "offset")
expect_subset("offset", task$properties)
expect_data_table(task$offset, nrows = task$nrow, ncols = 1)
expect_subset("age", names(task$offset))

expect_error({
task$col_roles$offset = c("glucose", "diabetes")
}, "There may only be up to one column with role")

expect_error({
task$col_roles$offset = c("glucose")
}, "contain missing values")

expect_warning(lrn("classif.rpart")$train(task), "has offset")
})

test_that("offset column role works with multiclass tasks", {
task = tsk("penguins")
expect_null(task$offset)
task$set_col_roles("year", "offset")
expect_subset("offset", task$properties)
expect_data_table(task$offset, nrows = task$nrow, ncols = 1)
expect_subset("year", names(task$offset))

expect_error({
task$col_roles$offset = "bill_length"
}, "contain missing values")

task = tsk("wine")

expect_error({
task$col_roles$offset = c("alcohol", "ash")
}, "Must be a subset of")

task = tsk("wine")
data = task$data()
set(data, j = "offset_1", value = runif(nrow(data)))
set(data, j = "offset_2", value = runif(nrow(data)))
task = as_task_classif(data, target = "type")
task$set_col_roles(c("offset_1", "offset_2"), "offset")

expect_subset("offset", task$properties)
expect_data_table(task$offset, nrows = task$nrow, ncols = 2)
expect_subset(c("offset_1", "offset_2"), names(task$offset))
})
18 changes: 18 additions & 0 deletions tests/testthat/test_TaskRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,21 @@ test_that("$add_strata", {
task$add_strata(task$target_names, bins = 2)
expect_identical(task$strata$N, c(50L, 10L))
})

test_that("offset column role works", {
task = tsk("mtcars")
expect_null(task$offset)
task$set_col_roles("am", "offset")

expect_subset("offset", task$properties)
expect_data_table(task$offset, nrows = task$nrow, ncols = 1)
expect_subset("am", names(task$offset))

expect_error({
task$col_roles$offset = c("am", "gear")
}, "up to one")

task$col_roles$offset = character()
expect_true("offset" %nin% task$properties)
expect_null(task$offset)
})
bblodfon marked this conversation as resolved.
Show resolved Hide resolved