Skip to content

Commit

Permalink
added protection against use of 'cb.early.stop'
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Oct 25, 2019
1 parent 6cd303a commit b24ecfc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
22 changes: 17 additions & 5 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,25 @@ lgb.cv <- function(params = list(),
early_stopping_rounds <- params[[first_early_stop_param_name]]
}

using_early_stopping <- !is.null(early_stopping_rounds)
if (using_early_stopping && identical(params$boosting, "dart")){
warning("Early stopping is not available in 'dart' mode")
use_early_stopping <- FALSE
# Did user pass parameters that indicate they want to use early stopping?
using_early_stopping_via_args <- !is.null(early_stopping_rounds)

# Cannot use early stopping with 'dart' boosting
if (identical(params$boosting, "dart")){
warning("Early stopping is not available in 'dart' mode.")
using_early_stopping_via_args <- FALSE

# Remove the cb.early.stop() function if it was passed in to callbacks
callbacks <- Filter(
f = function(cb_func){
!identical(attr(cb_func, "name"), "cb.early.stop")
}
, x = callbacks
)
}

if (using_early_stopping){
# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping_via_args){
callbacks <- add.cb(
callbacks
, cb.early.stop(
Expand Down
18 changes: 15 additions & 3 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,25 @@ lgb.train <- function(params = list(),
early_stopping_rounds <- params[[first_early_stop_param_name]]
}

using_early_stopping <- !is.null(early_stopping_rounds)
# Did user pass parameters that indicate they want to use early stopping?
using_early_stopping_via_args <- !is.null(early_stopping_rounds)

# Cannot use early stopping with 'dart' boosting
if (identical(params$boosting, "dart")){
warning("Early stopping is not available in 'dart' mode.")
use_early_stopping <- FALSE
using_early_stopping_via_args <- FALSE

# Remove the cb.early.stop() function if it was passed in to callbacks
callbacks <- Filter(
f = function(cb_func){
!identical(attr(cb_func, "name"), "cb.early.stop")
}
, x = callbacks
)
}

if (using_early_stopping){
# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping_via_args){
callbacks <- add.cb(
callbacks
, cb.early.stop(
Expand Down

0 comments on commit b24ecfc

Please sign in to comment.