From 6221426aab70a153f9e866464f21069e8a200485 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 29 Sep 2019 23:02:49 -0500 Subject: [PATCH] added support for boosting aliases in R package and test on skipping early boosting --- R-package/R/aliases.R | 15 ++++++++++++ R-package/R/callback.R | 7 +++++- R-package/R/lgb.cv.R | 12 +++++++++- R-package/R/lgb.train.R | 12 +++++++++- R-package/tests/testthat/test_parameters.R | 27 ++++++++++++++++++++++ 5 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 R-package/R/aliases.R diff --git a/R-package/R/aliases.R b/R-package/R/aliases.R new file mode 100644 index 000000000000..80d411f074b1 --- /dev/null +++ b/R-package/R/aliases.R @@ -0,0 +1,15 @@ +# Central location for paramter aliases. +# See https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters + +# [description] List of respected parameter aliases. Wrapped in a function to take advantage of +# lazy evaluation (so it doesn't matter what order R sources files during installation). +# [return] A named list, where each key is a main LightGBM parameter and each value is a character +# vector of corresponding aliases. +.PARAMETER_ALIASES <- function(){ + return(list( + "boosting" = c( + "boosting_type" + , "boost" + ) + )) +} diff --git a/R-package/R/callback.R b/R-package/R/callback.R index 92bd9c035a97..d6a010b19296 100644 --- a/R-package/R/callback.R +++ b/R-package/R/callback.R @@ -37,7 +37,12 @@ cb.reset.parameters <- function(new_params) { # Some parameters are not allowed to be changed, # since changing them would simply wreck some chaos - not_allowed <- c("num_class", "metric", "boosting_type") + not_allowed <- c( + "num_class" + , "metric" + , "boosting" + , .PARAMETER_ALIASES()[["boosting"]] + ) if (any(pnames %in% not_allowed)) { stop( "Parameters " diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index a7766798b1e8..980b712b9148 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -238,8 +238,18 @@ lgb.cv <- function(params = list(), # Did user pass parameters that indicate they want to use early stopping? using_early_stopping_via_args <- !is.null(early_stopping_rounds) + boosting_param_names <- c("boosting", .PARAMETER_ALIASES()[["boosting"]]) + using_dart <- any( + sapply( + X = boosting_param_names + , FUN = function(param){ + identical(params[[param]], 'dart') + } + ) + ) + # Cannot use early stopping with 'dart' boosting - if (identical(params$boosting, "dart")){ + if (using_dart){ warning("Early stopping is not available in 'dart' mode.") using_early_stopping_via_args <- FALSE diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index ce9358137b74..1900637fcdfc 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -220,8 +220,18 @@ lgb.train <- function(params = list(), # Did user pass parameters that indicate they want to use early stopping? using_early_stopping_via_args <- !is.null(early_stopping_rounds) + boosting_param_names <- c("boosting", .PARAMETER_ALIASES()[["boosting"]]) + using_dart <- any( + sapply( + X = boosting_param_names + , FUN = function(param){ + identical(params[[param]], 'dart') + } + ) + ) + # Cannot use early stopping with 'dart' boosting - if (identical(params$boosting, "dart")){ + if (using_dart){ warning("Early stopping is not available in 'dart' mode.") using_early_stopping_via_args <- FALSE diff --git a/R-package/tests/testthat/test_parameters.R b/R-package/tests/testthat/test_parameters.R index 60a762de2e59..9d6c75c861e5 100644 --- a/R-package/tests/testthat/test_parameters.R +++ b/R-package/tests/testthat/test_parameters.R @@ -43,3 +43,30 @@ test_that("Feature penalties work properly", { # Ensure that feature is not used when feature_penalty = 0 expect_length(var_gain[[length(var_gain)]], 0) }) + +expect_true(".PARAMETER_ALIASES() returns a named list", { + param_aliases <- .PARAMETER_ALIASES() + expect_true(is.list(param_aliases)) + expect_true(is.character(names(param_aliases))) +}) + +expect_true("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", { + for (boosting_param in c("boosting", .PARAMETER_ALIASES()[["boosting"]])){ + expect_warning({ + result <- lightgbm( + data = train$data + , label = train$label + , num_leaves = 5 + , learning_rate = 0.05 + , nrounds = 5 + , objective = "binary" + , metric = "binary_error" + , verbose = -1 + , params = stats::setNames( + object = "dart" + , nm = boosting_param + ) + ) + }, regexp = "Early stopping is not available in 'dart' mode") + } +})