[R-package] added support for first_metric_only (fixes #2368) (#2912)
* [R-package] started implementing first_metric_only

* trying stuff

* more changes

* fixed handling of multiple metrics

* fixed tests

* remove duplicate tests

* get training tests

* fixes for lgb.cv()

* fixes for lgb.cv()

* fix linting
jameslamb authored Sep 6, 2020
1 parent 636e4ee commit d4325c5
Showing 12 changed files with 871 additions and 40 deletions.
14 changes: 14 additions & 0 deletions R-package/R/aliases.R
@@ -108,3 +108,17 @@
)
return(c(learning_params, .DATASET_PARAMETERS()))
}

# [description]
# Per https://github.com/microsoft/LightGBM/blob/master/docs/Parameters.rst#metric,
# a few different strings can be used to indicate "no metrics".
# [returns]
# A character vector
.NO_METRIC_STRINGS <- function() {
return(c(
"na"
, "None"
, "null"
, "custom"
))
}
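For illustration, a minimal sketch of how these placeholder strings can be filtered out of a set of metric names. Note that .NO_METRIC_STRINGS() is internal to the R package (not exported), so this is a standalone sketch rather than user-facing API:

# toy vector of metric names, including one "no metric" placeholder
metrics <- c("auc", "None", "binary_logloss")
Filter(
    f = function(metric) {
        !(metric %in% .NO_METRIC_STRINGS())
    }
    , x = metrics
)
# [1] "auc"            "binary_logloss"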
12 changes: 10 additions & 2 deletions R-package/R/callback.R
@@ -268,7 +268,7 @@ cb.record.evaluation <- function() {

}

cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
cb.early.stop <- function(stopping_rounds, first_metric_only = FALSE, verbose = TRUE) {

# Initialize variables
factor_to_bigger_better <- NULL
@@ -325,8 +325,16 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Store iteration
cur_iter <- env$iteration

# By default, any metric can trigger early stopping. This can be disabled
# with 'first_metric_only = TRUE'
if (isTRUE(first_metric_only)) {
evals_to_check <- 1L
} else {
evals_to_check <- seq_len(eval_len)
}

# Loop through evaluation
for (i in seq_len(eval_len)) {
for (i in evals_to_check) {

# Store score
score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]
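To make the new branch concrete, a standalone sketch (not package code) of how first_metric_only narrows which evaluation results the callback loops over, assuming three tracked metrics:

eval_len <- 3L
first_metric_only <- TRUE
if (isTRUE(first_metric_only)) {
    evals_to_check <- 1L
} else {
    evals_to_check <- seq_len(eval_len)
}
evals_to_check
# [1] 1    (with first_metric_only = FALSE, this would be 1 2 3)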
38 changes: 27 additions & 11 deletions R-package/R/lgb.cv.R
@@ -24,10 +24,6 @@ CVBooster <- R6::R6Class(
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of sample weights. If not NULL, will be set on the dataset
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}
#' @param eval evaluation function, can be (list of) character or custom eval function
#' @param record Boolean, if TRUE will record iteration messages to \code{booster$record_evals}
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
@@ -52,7 +48,7 @@ CVBooster <- R6::R6Class(
#' the number of real CPU cores, not the number of threads (most
#' CPUs use hyper-threading to generate 2 threads per CPU core).}
#' }
#'
#' @inheritSection lgb_shared_params Early Stopping
#' @return a trained model \code{lgb.CVBooster}.
#'
#' @examples
@@ -114,17 +110,25 @@ lgb.cv <- function(params = list()
params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval)
fobj <- NULL
feval <- NULL
eval_functions <- list(NULL)

# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
params$objective <- "NONE"
}

# Check for loss (function or not)
# If eval is a single function, store it as a 1-element list
# (for backwards compatibility). If it is a list of functions, store
# all of them
if (is.function(eval)) {
feval <- eval
eval_functions <- list(eval)
}
if (methods::is(eval, "list")) {
eval_functions <- Filter(
f = is.function
, x = eval
)
}

# Init predictor to empty
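A sketch of the normalization above: whatever form eval takes, eval_functions ends up holding only functions, while character metrics are routed into params$metric by lgb.check.eval(). The custom metric here is a made-up example:

constant_metric <- function(preds, dtrain) {
    list(name = "constant", value = 42.0, higher_better = TRUE)
}
eval <- list("auc", constant_metric)
eval_functions <- Filter(
    f = is.function
    , x = eval
)
length(eval_functions)  # 1: only constant_metric survives the filter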
@@ -266,6 +270,7 @@ lgb.cv <- function(params = list()
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, first_metric_only = isTRUE(params[["first_metric_only"]])
, verbose = verbose
)
)
@@ -357,7 +362,11 @@ lgb.cv <- function(params = list()
# Update one boosting iteration
msg <- lapply(cv_booster$boosters, function(fd) {
fd$booster$update(fobj = fobj)
fd$booster$eval_valid(feval = feval)
out <- list()
for (eval_function in eval_functions) {
out <- append(out, fd$booster$eval_valid(feval = eval_function))
}
return(out)
})

# Prepare collection of evaluation results
@@ -384,7 +393,13 @@
# When early stopping is not activated, we compute the best iteration / score ourselves
# based on the first metric
if (record && is.na(env$best_score)) {
first_metric <- cv_booster$boosters[[1L]][[1L]]$.__enclos_env__$private$eval_names[1L]
# when using a custom eval function, the metric name is returned from the
# function, so figure it out from record_evals
if (length(eval_functions) > 0L && is.function(eval_functions[[1L]])) {
first_metric <- names(cv_booster$record_evals[["valid"]])[1L]
} else {
first_metric <- cv_booster$boosters[[1L]][[1L]]$.__enclos_env__$private$eval_names[1L]
}
.find_best <- which.min
if (isTRUE(env$eval_list[[1L]]$higher_better[1L])) {
.find_best <- which.max
@@ -576,7 +591,8 @@ lgb.merge.cv.result <- function(msg, showsd = TRUE) {
msg[[i]][[j]]$value }))
})

# Get evaluation
# Get evaluation. Just taking the first element here to
# get structure (name, higher_better, data_name)
ret_eval <- msg[[1L]]

# Go through evaluation length items
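Putting the lgb.cv() changes together, a usage sketch based on the agaricus data bundled with the package; only the first metric ("auc") is allowed to trigger early stopping:

library(lightgbm)

data(agaricus.train, package = "lightgbm")
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)

cv_bst <- lgb.cv(
    params = list(
        objective = "binary"
        , metric = c("auc", "binary_logloss")
        , first_metric_only = TRUE
    )
    , data = dtrain
    , nrounds = 100L
    , nfold = 3L
    , early_stopping_rounds = 5L
)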
47 changes: 34 additions & 13 deletions R-package/R/lgb.train.R
@@ -3,10 +3,6 @@
#' @description Logic to train with LightGBM
#' @inheritParams lgb_shared_params
#' @param valids a list of \code{lgb.Dataset} objects, used for validation
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}
#' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param record Boolean, if TRUE will record iteration messages to \code{booster$record_evals}
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int
@@ -26,6 +22,7 @@
#' the number of real CPU cores, not the number of threads (most
#' CPUs use hyper-threading to generate 2 threads per CPU core).}
#' }
#' @inheritSection lgb_shared_params Early Stopping
#' @return a trained booster model \code{lgb.Booster}.
#'
#' @examples
@@ -90,17 +87,25 @@ lgb.train <- function(params = list(),
params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval)
fobj <- NULL
feval <- NULL
eval_functions <- list(NULL)

# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
params$objective <- "NONE"
}

# Check for loss (function or not)
# If eval is a single function, store it as a 1-element list
# (for backwards compatibility). If it is a list of functions, store
# all of them
if (is.function(eval)) {
feval <- eval
eval_functions <- list(eval)
}
if (methods::is(eval, "list")) {
eval_functions <- Filter(
f = is.function
, x = eval
)
}

# Init predictor to empty
@@ -235,6 +240,7 @@ lgb.train <- function(params = list(),
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, first_metric_only = isTRUE(params[["first_metric_only"]])
, verbose = verbose
)
)
@@ -280,13 +286,28 @@ lgb.train <- function(params = list(),
# Collection: Has validation dataset?
if (length(valids) > 0L) {

# Validation has training dataset?
if (valid_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = feval))
# Get evaluation results with passed-in functions
for (eval_function in eval_functions) {

# Validation has training dataset?
if (valid_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = eval_function))
}

eval_list <- append(eval_list, booster$eval_valid(feval = eval_function))
}

# Calling booster$eval_valid() with feval = NULL gets evaluation
# results for the metrics in params$metric by calling LGBM_BoosterGetEval_R,
# so be sure that happens, which it wouldn't above if no eval
# functions were passed in
if (length(eval_functions) == 0L) {
if (valid_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = NULL))
}
eval_list <- append(eval_list, booster$eval_valid(feval = NULL))
}

# Has no validation dataset
eval_list <- append(eval_list, booster$eval_valid(feval = feval))
}

# Write evaluation result in environment
@@ -312,7 +333,7 @@ lgb.train <- function(params = list(),

# when using a custom eval function, the metric name is returned from the
# function, so figure it out from record_evals
if (!is.null(feval)) {
if (length(eval_functions) > 0L && is.function(eval_functions[[1L]])) {
first_metric <- names(booster$record_evals[[first_valid_name]])[1L]
} else {
first_metric <- booster$.__enclos_env__$private$eval_names[1L]
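The same flow through lgb.train(), sketched with a held-out validation set; training can only be stopped early by "auc", the first metric:

library(lightgbm)

data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
dtest <- lgb.Dataset.create.valid(dtrain, agaricus.test$data, label = agaricus.test$label)

bst <- lgb.train(
    params = list(
        objective = "binary"
        , metric = c("auc", "binary_logloss")
        , first_metric_only = TRUE
    )
    , data = dtrain
    , nrounds = 100L
    , valids = list(test = dtest)
    , early_stopping_rounds = 5L
)
bst$best_iter  # best iteration, chosen using only "auc"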
51 changes: 51 additions & 0 deletions R-package/R/lightgbm.R
@@ -10,11 +10,61 @@
#' and one metric. If there's more than one, will check all of them
#' except the training data. Returns the model with (best_iter + early_stopping_rounds).
#' If early stopping occurs, the model will have 'best_iter' field.
#' @param eval evaluation function(s). This can be a character vector, function, or list with a mixture of
#' strings and functions.
#'
#' \itemize{
#' \item{\bold{a. character vector}:
#' If you provide a character vector to this argument, it should contain strings with valid
#' evaluation metrics.
#' See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric}{
#' The "metric" section of the documentation}
#' for a list of valid metrics.
#' }
#' \item{\bold{b. function}:
#' You can provide a custom evaluation function. This
#' should accept two arguments, \code{preds} and \code{dtrain}, and should return a named
#' list with three elements:
#' \itemize{
#' \item{\code{name}: A string with the name of the metric, used for printing
#' and storing results.
#' }
#' \item{\code{value}: A single number indicating the value of the metric for the
#' given predictions and true values
#' }
#' \item{
#' \code{higher_better}: A boolean indicating whether higher values indicate a better fit.
#' For example, this would be \code{FALSE} for metrics like MAE or RMSE.
#' }
#' }
#' }
#' \item{\bold{c. list}:
#' If a list is given, it should only contain character vectors and functions.
#' These should follow the requirements from the descriptions above.
#' }
#' }
#' @param eval_freq evaluation output frequency, only has an effect when verbose > 0
#' @param init_model path of model file of \code{lgb.Booster} object, will continue training from this model
#' @param nrounds number of training rounds
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}
#' @param params List of parameters
#' @param verbose verbosity for output; if <= 0, printing of evaluation results during training is also disabled
#' @section Early Stopping:
#'
#' "early stopping" refers to stopping the training process if the model's performance on a given
#' validation set does not improve for several consecutive iterations.
#'
#' If multiple arguments are given to \code{eval}, their order will be preserved. If you enable
#' early stopping by setting \code{early_stopping_rounds} in \code{params}, by default all
#' metrics will be considered for early stopping.
#'
#' If you want to only consider the first metric for early stopping, pass
#' \code{first_metric_only = TRUE} in \code{params}. Note that if you also specify \code{metric}
#' in \code{params}, that metric will be considered the "first" one. If you omit \code{metric},
#' a default metric will be used based on your choice for the parameter \code{obj} (keyword argument)
#' or \code{objective} (passed into \code{params}).
#' @keywords internal
NULL

@@ -47,6 +97,7 @@ NULL
#' the number of real CPU cores, not the number of threads (most
#' CPUs use hyper-threading to generate 2 threads per CPU core).}
#' }
#' @inheritSection lgb_shared_params Early Stopping
#' @export
lightgbm <- function(data,
label = NULL,
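A sketch of a custom evaluation function meeting the contract documented above. The accuracy metric is an arbitrary illustration, and it is assumed here that for objective = "binary" the preds passed to a custom eval function are raw scores, so a threshold of 0 corresponds to a probability of 0.5:

accuracy <- function(preds, dtrain) {
    labels <- getinfo(dtrain, "label")
    return(list(
        name = "accuracy"
        , value = mean(as.integer(preds > 0.0) == labels)
        , higher_better = TRUE
    ))
}

# it can then be mixed with string metrics, e.g.
# lgb.train(..., eval = list("auc", accuracy))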
33 changes: 25 additions & 8 deletions R-package/R/utils.R
@@ -318,10 +318,10 @@ lgb.check.obj <- function(params, obj) {
}

# [description]
# make sure that "metric" is populated on params,
# and add any eval values to it
# [return]
# params, where "metric" is a list
# Take any character values from eval and store them in params$metric.
# This has to account for the fact that `eval` could be a character vector,
# a function, a list of functions, or a list with a mix of strings and
# functions
lgb.check.eval <- function(params, eval) {

if (is.null(params$metric)) {
@@ -330,13 +330,30 @@ lgb.check.eval <- function(params, eval) {
params$metric <- as.list(params$metric)
}

if (is.character(eval)) {
params$metric <- append(params$metric, eval)
# if 'eval' is a character vector or list, find the character
# elements and add them to 'metric'
if (!is.function(eval)) {
for (i in seq_along(eval)) {
element <- eval[[i]]
if (is.character(element)) {
params$metric <- append(params$metric, element)
}
}
}

if (identical(class(eval), "list")) {
params$metric <- append(params$metric, unlist(eval))
# If more than one metric was given, placeholder strings like
# "None" should not be included
if (length(params$metric) > 1L) {
params$metric <- Filter(
f = function(metric) {
!(metric %in% .NO_METRIC_STRINGS())
}
, x = params$metric
)
}

# duplicate metrics should be filtered out
params$metric <- as.list(unique(unlist(params$metric)))

return(params)
}
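The two filters at the end can be sketched standalone: placeholder strings are dropped once a real metric is present, and duplicates collapse to one entry:

metrics <- list("None", "auc", "auc")
if (length(metrics) > 1L) {
    metrics <- Filter(
        f = function(metric) {
            !(metric %in% c("na", "None", "null", "custom"))
        }
        , x = metrics
    )
}
as.list(unique(unlist(metrics)))
# [[1]] "auc"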