Skip to content

Commit

Permalink
more use of shared parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Oct 25, 2019
1 parent de02626 commit 2d230ff
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 27 deletions.
17 changes: 17 additions & 0 deletions R-package/R/aliases.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
, "boost"
, "boosting_type"
)
, "early_stopping_round" = c(
"early_stopping_round"
, "early_stopping_rounds"
, "early_stopping"
, "n_iter_no_change"
)
, "metric" = c(
"metric"
, "metrics"
Expand All @@ -21,5 +27,16 @@
"num_class"
, "num_classes"
)
, "num_iterations" = c(
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
))
}
16 changes: 3 additions & 13 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,7 @@ lgb.cv <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1
}
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
n_trees <- c(
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
n_rounds <- .PARAMETER_ALIASES()[["num_iterations"]]
if (any(names(params) %in% n_trees)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
} else {
Expand Down Expand Up @@ -227,7 +217,7 @@ lgb.cv <- function(params = list(),

# If early stopping was passed as a parameter in params(), prefer that to keyword argument
# early_stopping_rounds by overwriting the value in 'early_stopping_rounds'
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change")
early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
early_stop_param_indx <- names(params) %in% early_stop
if (any(early_stop_param_indx)) {
first_early_stop_param <- which(early_stop_param_indx)[[1]]
Expand All @@ -238,7 +228,7 @@ lgb.cv <- function(params = list(),
# Did user pass parameters that indicate they want to use early stopping?
using_early_stopping_via_args <- !is.null(early_stopping_rounds)

boosting_param_names <- c("boosting", .PARAMETER_ALIASES()[["boosting"]])
boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
using_dart <- any(
sapply(
X = boosting_param_names
Expand Down
16 changes: 3 additions & 13 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,7 @@ lgb.train <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1
}
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
n_rounds <- c(
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
n_rounds <- .PARAMETER_ALIASES()[["num_iterations"]]
if (any(names(params) %in% n_rounds)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_rounds)[1]]] - 1
} else {
Expand Down Expand Up @@ -209,7 +199,7 @@ lgb.train <- function(params = list(),

# If early stopping was passed as a parameter in params(), prefer that to keyword argument
# early_stopping_rounds by overwriting the value in 'early_stopping_rounds'
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change")
early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
early_stop_param_indx <- names(params) %in% early_stop
if (any(early_stop_param_indx)) {
first_early_stop_param <- which(early_stop_param_indx)[[1]]
Expand All @@ -220,7 +210,7 @@ lgb.train <- function(params = list(),
# Did user pass parameters that indicate they want to use early stopping?
using_early_stopping_via_args <- !is.null(early_stopping_rounds)

boosting_param_names <- c("boosting", .PARAMETER_ALIASES()[["boosting"]])
boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
using_dart <- any(
sapply(
X = boosting_param_names
Expand Down
2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ expect_true(".PARAMETER_ALIASES() returns a named list", {
})

expect_true("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", {
for (boosting_param in c("boosting", .PARAMETER_ALIASES()[["boosting"]])){
for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]){
expect_warning({
result <- lightgbm(
data = train$data
Expand Down

0 comments on commit 2d230ff

Please sign in to comment.