Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Disabled early stopping when using 'dart' boosting strategy #2443

Merged
merged 11 commits into from
Oct 25, 2019
42 changes: 42 additions & 0 deletions R-package/R/aliases.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Central location for parameter aliases.
# See https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters

# [description] Lookup table of the parameter aliases respected by this package. It is exposed as a
#               function instead of a plain object so that the table is only built when it is
#               called (so it doesn't matter what order R sources files during installation).
# [return] A named list. Each name is a main LightGBM parameter and each element is a character
#          vector holding that parameter's own name followed by its aliases.
.PARAMETER_ALIASES <- function() {

    # Each vector starts with the main parameter name, then lists its aliases
    boosting_aliases <- c("boosting", "boost", "boosting_type")
    early_stopping_aliases <- c(
        "early_stopping_round",
        "early_stopping_rounds",
        "early_stopping",
        "n_iter_no_change"
    )
    metric_aliases <- c("metric", "metrics", "metric_types")
    num_class_aliases <- c("num_class", "num_classes")
    num_iterations_aliases <- c(
        "num_iterations",
        "num_iteration",
        "n_iter",
        "num_tree",
        "num_trees",
        "num_round",
        "num_rounds",
        "num_boost_round",
        "n_estimators"
    )

    # Keys are the main parameter names; the order of the entries is part of the result
    list(
        boosting = boosting_aliases,
        early_stopping_round = early_stopping_aliases,
        metric = metric_aliases,
        num_class = num_class_aliases,
        num_iterations = num_iterations_aliases
    )
}
6 changes: 5 additions & 1 deletion R-package/R/callback.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ cb.reset.parameters <- function(new_params) {

# Some parameters are not allowed to be changed,
# since changing them would simply wreak havoc
not_allowed <- c("num_class", "metric", "boosting_type")
not_allowed <- c(
.PARAMETER_ALIASES()[["num_class"]]
, .PARAMETER_ALIASES()[["metric"]]
, .PARAMETER_ALIASES()[["boosting"]]
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
)
if (any(pnames %in% not_allowed)) {
stop(
"Parameters "
Expand Down
80 changes: 46 additions & 34 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,7 @@ lgb.cv <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1
}
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
n_trees <- c(
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
n_trees <- .PARAMETER_ALIASES()[["num_iterations"]]
if (any(names(params) %in% n_trees)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
} else {
Expand Down Expand Up @@ -225,30 +215,52 @@ lgb.cv <- function(params = list(),
callbacks <- add.cb(callbacks, cb.record.evaluation())
}

# Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change")
if (any(names(params) %in% early_stop)) {
if (params[[which(names(params) %in% early_stop)[1]]] > 0) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
params[[which(names(params) %in% early_stop)[1]]]
, verbose = verbose
)
)
}
} else {
if (!is.null(early_stopping_rounds)) {
if (early_stopping_rounds > 0) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
early_stopping_rounds
, verbose = verbose
)
)
# If early stopping was passed as a parameter in 'params', prefer that to the keyword argument
# 'early_stopping_rounds' by overwriting the value of 'early_stopping_rounds'
early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
early_stop_param_indx <- names(params) %in% early_stop
if (any(early_stop_param_indx)) {
first_early_stop_param <- which(early_stop_param_indx)[[1]]
first_early_stop_param_name <- names(params)[[first_early_stop_param]]
early_stopping_rounds <- params[[first_early_stop_param_name]]
}

# Did user pass parameters that indicate they want to use early stopping?
using_early_stopping_via_args <- !is.null(early_stopping_rounds)

boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
using_dart <- any(
sapply(
X = boosting_param_names
, FUN = function(param){
identical(params[[param]], 'dart')
}
}
)
)

# Cannot use early stopping with 'dart' boosting
if (using_dart){
warning("Early stopping is not available in 'dart' mode.")
using_early_stopping_via_args <- FALSE

# Remove the cb.early.stop() function if it was passed in to callbacks
callbacks <- Filter(
f = function(cb_func){
!identical(attr(cb_func, "name"), "cb.early.stop")
}
, x = callbacks
)
}

# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping_via_args){
callbacks <- add.cb(
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, verbose = verbose
)
)
}

# Categorize callbacks
Expand Down
85 changes: 48 additions & 37 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,13 @@ lgb.train <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1
}
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
n_rounds <- c(
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
if (any(names(params) %in% n_rounds)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_rounds)[1]]] - 1
n_trees <- .PARAMETER_ALIASES()[["num_iterations"]]
if (any(names(params) %in% n_trees)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
} else {
end_iteration <- begin_iteration + nrounds - 1
}


# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object")
Expand Down Expand Up @@ -207,30 +196,52 @@ lgb.train <- function(params = list(),
callbacks <- add.cb(callbacks, cb.record.evaluation())
}

# Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change")
if (any(names(params) %in% early_stop)) {
if (params[[which(names(params) %in% early_stop)[1]]] > 0) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
params[[which(names(params) %in% early_stop)[1]]]
, verbose = verbose
)
)
}
} else {
if (!is.null(early_stopping_rounds)) {
if (early_stopping_rounds > 0) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
early_stopping_rounds
, verbose = verbose
)
)
# If early stopping was passed as a parameter in 'params', prefer that to the keyword argument
# 'early_stopping_rounds' by overwriting the value of 'early_stopping_rounds'
early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
early_stop_param_indx <- names(params) %in% early_stop
if (any(early_stop_param_indx)) {
first_early_stop_param <- which(early_stop_param_indx)[[1]]
first_early_stop_param_name <- names(params)[[first_early_stop_param]]
early_stopping_rounds <- params[[first_early_stop_param_name]]
}

# Did user pass parameters that indicate they want to use early stopping?
using_early_stopping_via_args <- !is.null(early_stopping_rounds)

boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
using_dart <- any(
sapply(
X = boosting_param_names
, FUN = function(param){
identical(params[[param]], 'dart')
}
}
)
)

# Cannot use early stopping with 'dart' boosting
if (using_dart){
warning("Early stopping is not available in 'dart' mode.")
using_early_stopping_via_args <- FALSE

# Remove the cb.early.stop() function if it was passed in to callbacks
callbacks <- Filter(
f = function(cb_func){
!identical(attr(cb_func, "name"), "cb.early.stop")
}
, x = callbacks
)
}

# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping_via_args){
callbacks <- add.cb(
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, verbose = verbose
)
)
}

# "Categorize" callbacks
Expand Down
32 changes: 32 additions & 0 deletions R-package/tests/testthat/test_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,35 @@ test_that("Feature penalties work properly", {
# Ensure that feature is not used when feature_penalty = 0
expect_length(var_gain[[length(var_gain)]], 0)
})

# Checks the shape of the alias table: it must be a list with character names, and every
# main parameter looked up below must map to a character vector of aliases.
# NOTE: this block has to be wrapped in test_that(). Wrapping it in expect_true() would treat
# the description string as the object under test instead of registering a test case.
test_that(".PARAMETER_ALIASES() returns a named list", {
    param_aliases <- .PARAMETER_ALIASES()
    expect_true(is.list(param_aliases))
    expect_true(is.character(names(param_aliases)))
    expect_true(is.character(param_aliases[["boosting"]]))
    expect_true(is.character(param_aliases[["early_stopping_round"]]))
    expect_true(is.character(param_aliases[["metric"]]))
    expect_true(is.character(param_aliases[["num_class"]]))
    expect_true(is.character(param_aliases[["num_iterations"]]))
})

# For every accepted spelling of the 'boosting' parameter, training with 'dart' must raise the
# "Early stopping is not available in 'dart' mode" warning.
# NOTE: this block has to be wrapped in test_that(). Wrapping it in expect_true() would treat
# the description string as the object under test instead of registering a test case.
# NOTE(review): 'train' is not defined in this block; it is presumably created earlier in this
# test file -- confirm.
test_that("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", {
    for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]){
        expect_warning({
            result <- lightgbm(
                data = train$data
                , label = train$label
                , num_leaves = 5
                , learning_rate = 0.05
                , nrounds = 5
                , objective = "binary"
                , metric = "binary_error"
                , verbose = -1
                # Pass 'dart' under the alias currently being tested
                , params = stats::setNames(
                    object = "dart"
                    , nm = boosting_param
                )
            )
        }, regexp = "Early stopping is not available in 'dart' mode")
    }
})