Skip to content

Commit

Permalink
Merge pull request #122 from mlr-org/fix_lightgbm_categorical
Browse files Browse the repository at this point in the history
fix categorical_features in lightgbm
  • Loading branch information
RaphaelS1 authored Oct 20, 2021
2 parents def5074 + 650d05d commit df6b209
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3extralearners
Title: Extra Learners For mlr3
Version: 0.5.12
Version: 0.5.13
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3extralearners 0.5.13

* Fix `categorical_features` in {lightgbm} learners

# mlr3extralearners 0.5.12

* Patch for `lightgbm` updates
Expand Down
11 changes: 9 additions & 2 deletions R/learner_lightgbm_classif_lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#' @details
#' For categorical features either pre-process data by encoding columns or
#' specify the categorical columns with the `categorical_feature` parameter.
#' For this learner please do not prefix the categorical feature with `name:`.
#'
#' @template seealso_learner
#' @template example
Expand Down Expand Up @@ -285,9 +286,12 @@ LearnerClassifLightGBM = R6Class("LearnerClassifLightGBM",
dtrain = lightgbm::lgb.Dataset(
data = as.matrix(task$data(rows = train_ids, cols = task$feature_names)),
label = train_label,
free_raw_data = FALSE
free_raw_data = FALSE,
categorical_feature = pars$categorical_feature
)

pars$categorical_feature <- NULL

dtest = lightgbm::lgb.Dataset.create.valid(
dataset = dtrain,
data = as.matrix(task$data(rows = valid_ids, cols = task$feature_names)),
Expand Down Expand Up @@ -318,9 +322,12 @@ LearnerClassifLightGBM = R6Class("LearnerClassifLightGBM",
dtrain = lightgbm::lgb.Dataset(
data = as.matrix(task$data(cols = task$feature_names)),
label = train_label,
free_raw_data = FALSE
free_raw_data = FALSE,
categorical_feature = pars$categorical_feature
)

pars$categorical_feature <- NULL

if ("weights" %in% task$properties) {
dtrain$setinfo("weight", task$weights$weight)
}
Expand Down
12 changes: 9 additions & 3 deletions R/learner_lightgbm_regr_lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
#' - Adjusted default: -1L
#' - Reason for change: Prevents accidental conflicts with mlr messaging system.
#'
#' @details
#' For categorical features either pre-process data by encoding columns or
#' specify the categorical columns with the `categorical_feature` parameter.
#' For this learner please do not prefix the categorical feature with `name:`.
#'
#' @template seealso_learner
#' @template example
Expand Down Expand Up @@ -233,9 +233,12 @@ LearnerRegrLightGBM = R6Class("LearnerRegrLightGBM",
dtrain = lightgbm::lgb.Dataset(
data = as.matrix(task$data(rows = train_ids, cols = task$feature_names)),
label = as.matrix(task$data(rows = train_ids, cols = task$target_names)),
free_raw_data = FALSE
free_raw_data = FALSE,
categorical_feature = pars$categorical_feature
)

pars$categorical_feature <- NULL

valid_ids = task$row_roles$validation
dtest = lightgbm::lgb.Dataset.create.valid(
dataset = dtrain,
Expand All @@ -258,9 +261,12 @@ LearnerRegrLightGBM = R6Class("LearnerRegrLightGBM",
dtrain = lightgbm::lgb.Dataset(
data = as.matrix(task$data(cols = task$feature_names)),
label = as.matrix(task$data(cols = task$target_names)),
free_raw_data = FALSE
free_raw_data = FALSE,
categorical_feature = pars$categorical_feature
)

pars$categorical_feature <- NULL

if ("weights" %in% task$properties) {
dtrain$setinfo("weight", task$weights$weight)
}
Expand Down

0 comments on commit df6b209

Please sign in to comment.