diff --git a/R-package/R/aliases.R b/R-package/R/aliases.R index 8ae11e87c504..c7ecf8f72410 100644 --- a/R-package/R/aliases.R +++ b/R-package/R/aliases.R @@ -12,6 +12,12 @@ , "boost" , "boosting_type" ) + , "early_stopping_round" = c( + "early_stopping_round" + , "early_stopping_rounds" + , "early_stopping" + , "n_iter_no_change" + ) , "metric" = c( "metric" , "metrics" @@ -21,5 +27,16 @@ "num_class" , "num_classes" ) + , "num_iterations" = c( + "num_iterations" + , "num_iteration" + , "n_iter" + , "num_tree" + , "num_trees" + , "num_round" + , "num_rounds" + , "num_boost_round" + , "n_estimators" + ) )) } diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index 980b712b9148..3a8a57d7c2fe 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -136,17 +136,7 @@ lgb.cv <- function(params = list(), begin_iteration <- predictor$current_iter() + 1 } # Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one - n_trees <- c( - "num_iterations" - , "num_iteration" - , "n_iter" - , "num_tree" - , "num_trees" - , "num_round" - , "num_rounds" - , "num_boost_round" - , "n_estimators" - ) + n_rounds <- .PARAMETER_ALIASES()[["num_iterations"]] if (any(names(params) %in% n_trees)) { end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1 } else { @@ -227,7 +217,7 @@ lgb.cv <- function(params = list(), # If early stopping was passed as a parameter in params(), prefer that to keyword argument # early_stopping_rounds by overwriting the value in 'early_stopping_rounds' - early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change") + early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]] early_stop_param_indx <- names(params) %in% early_stop if (any(early_stop_param_indx)) { first_early_stop_param <- which(early_stop_param_indx)[[1]] @@ -238,7 +228,7 @@ lgb.cv <- function(params = list(), # Did user pass parameters that indicate they want to use early stopping? using_early_stopping_via_args <- !is.null(early_stopping_rounds) - boosting_param_names <- c("boosting", .PARAMETER_ALIASES()[["boosting"]]) + boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]] using_dart <- any( sapply( X = boosting_param_names diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index 1900637fcdfc..852ae146c7a2 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -108,17 +108,7 @@ lgb.train <- function(params = list(), begin_iteration <- predictor$current_iter() + 1 } # Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one - n_rounds <- c( - "num_iterations" - , "num_iteration" - , "n_iter" - , "num_tree" - , "num_trees" - , "num_round" - , "num_rounds" - , "num_boost_round" - , "n_estimators" - ) + n_rounds <- .PARAMETER_ALIASES()[["num_iterations"]] if (any(names(params) %in% n_rounds)) { end_iteration <- begin_iteration + params[[which(names(params) %in% n_rounds)[1]]] - 1 } else { @@ -209,7 +199,7 @@ lgb.train <- function(params = list(), # If early stopping was passed as a parameter in params(), prefer that to keyword argument # early_stopping_rounds by overwriting the value in 'early_stopping_rounds' - early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change") + early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]] early_stop_param_indx <- names(params) %in% early_stop if (any(early_stop_param_indx)) { first_early_stop_param <- which(early_stop_param_indx)[[1]] @@ -220,7 +210,7 @@ lgb.train <- function(params = list(), # Did user pass parameters that indicate they want to use early stopping? using_early_stopping_via_args <- !is.null(early_stopping_rounds) - boosting_param_names <- c("boosting", .PARAMETER_ALIASES()[["boosting"]]) + boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]] using_dart <- any( sapply( X = boosting_param_names diff --git a/R-package/tests/testthat/test_parameters.R b/R-package/tests/testthat/test_parameters.R index beedbe452ba6..a82fb0df7c6a 100644 --- a/R-package/tests/testthat/test_parameters.R +++ b/R-package/tests/testthat/test_parameters.R @@ -54,7 +54,7 @@ expect_true(".PARAMETER_ALIASES() returns a named list", { }) expect_true("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", { - for (boosting_param in c("boosting", .PARAMETER_ALIASES()[["boosting"]])){ + for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]){ expect_warning({ result <- lightgbm( data = train$data