Skip to content

Commit

Permalink
[R-package] fix issue where early stopping thinks higher MAPE is desi…
Browse files Browse the repository at this point in the history
…rable (fixes microsoft#3099) (microsoft#3101)

* [R-package] fix issue where early stopping thinks higher MAPE is desirable (fixes microsoft#3099)

* fix linting

* only use main metrics

* fix tests
  • Loading branch information
jameslamb authored and ChipKerchner committed Jun 10, 2020
1 parent 54a0b7a commit 2a06f0e
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 5 deletions.
6 changes: 3 additions & 3 deletions R-package/R/callback.R
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
}

# Maximization or minimization task
# Internally treat everything as a maximization task
factor_to_bigger_better <<- rep.int(1.0, eval_len)
best_iter <<- rep.int(-1L, eval_len)
best_score <<- rep.int(-Inf, eval_len)
Expand All @@ -305,8 +305,8 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Prepend message
best_msg <<- c(best_msg, "")

# Check if maximization or minimization
if (!env$eval_list[[i]]$higher_better) {
# Internally treat everything as a maximization task
if (!isTRUE(env$eval_list[[i]]$higher_better)) {
factor_to_bigger_better[i] <<- -1.0
}

Expand Down
6 changes: 5 additions & 1 deletion R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,11 @@ Booster <- R6::R6Class(
# Parse and store privately names
names <- strsplit(names, "\t")[[1L]]
private$eval_names <- names
private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc", names)

# some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
# ndcg metric evaluated at the first "query result" in learning-to-rank
metric_names <- gsub("@.*", "", names)
private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

}

Expand Down
34 changes: 34 additions & 0 deletions R-package/R/metrics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# [description] Lookup table giving the optimization direction of each metric known to
#               LightGBM. The most up to date list of metrics can be found at
#               https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric-parameters
#
# [return] A named logical vector keyed by metric name. A value of TRUE means that higher
#          values of the metric are desirable, FALSE means that lower values are desirable.
#          Only 'main' metric names are listed, not their aliases, because the C++ side
#          only ever reports the 'main' name. For example, training with `metric = "mse"`
#          results in the metric name `"l2"` being returned.
.METRICS_HIGHER_BETTER <- function() {

    # every known metric, in the order it appears in the returned vector
    known_metrics <- c(
        "l1"
        , "l2"
        , "mape"
        , "rmse"
        , "quantile"
        , "huber"
        , "fair"
        , "poisson"
        , "gamma"
        , "gamma_deviance"
        , "tweedie"
        , "ndcg"
        , "map"
        , "auc"
        , "binary_logloss"
        , "binary_error"
        , "auc_mu"
        , "multi_logloss"
        , "multi_error"
        , "cross_entropy"
        , "cross_entropy_lambda"
        , "kullback_leibler"
    )

    # the metrics for which a larger value is better; everything else is minimized
    maximized_metrics <- c(
        "ndcg"
        , "map"
        , "auc"
        , "auc_mu"
    )

    higher_better <- known_metrics %in% maximized_metrics
    names(higher_better) <- known_metrics
    return(higher_better)
}
133 changes: 132 additions & 1 deletion R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ test_that("lightgbm() performs evaluation on validation sets if they are provide
, num_leaves = 5L
, nrounds = nrounds
, objective = "binary"
, metric = "binary_error"
, metric = c(
"binary_error"
, "auc"
)
, valids = list(
"valid1" = dvalid1
, "valid2" = dvalid2
Expand Down Expand Up @@ -534,6 +537,73 @@ test_that("lgb.train() works with early stopping for classification", {

})

test_that("lgb.train() works with early stopping for classification with a metric that should be maximized", {
    set.seed(708L)
    nrounds <- 10L
    early_stopping_rounds <- 5L

    dtrain <- lgb.Dataset(
        data = train[["data"]]
        , label = train[["label"]]
    )
    dvalid <- lgb.Dataset(
        data = test[["data"]]
        , label = test[["label"]]
    )

    # Trains a binary classifier with early stopping enabled, evaluated on a single
    # validation set using only the given metric.
    # The harsh max_depth guarantees that AUC improves over at least the first few iterations.
    train_with_metric <- function(metric_name) {
        return(
            lgb.train(
                params = list(
                    objective = "binary"
                    , metric = metric_name
                    , max_depth = 3L
                    , early_stopping_rounds = early_stopping_rounds
                )
                , data = dtrain
                , nrounds = nrounds
                , valids = list(
                    "valid1" = dvalid
                )
            )
        )
    }

    # same training order as always: AUC model first, then the binary_error model
    bst_auc <- train_with_metric("auc")
    bst_binary_error <- train_with_metric("binary_error")

    # binary_error is a "lower is better" metric, so training should have been
    # cut short by early stopping
    error_private <- bst_binary_error$.__enclos_env__$private
    expect_identical(error_private$get_eval_info(), "binary_error")
    expect_identical(unname(error_private$higher_better_inner_eval), FALSE)
    expect_identical(bst_binary_error$best_iter, 1L)
    expect_identical(bst_binary_error$current_iter(), early_stopping_rounds + 1L)
    binary_error_gap <- abs(bst_binary_error$best_score - 0.01613904)
    expect_true(binary_error_gap < TOLERANCE)

    # AUC is a "higher is better" metric, so early stopping should never have
    # triggered and all rounds should have run
    auc_private <- bst_auc$.__enclos_env__$private
    expect_identical(auc_private$get_eval_info(), "auc")
    expect_identical(unname(auc_private$higher_better_inner_eval), TRUE)
    expect_identical(bst_auc$best_iter, 9L)
    expect_identical(bst_auc$current_iter(), nrounds)
    auc_gap <- abs(bst_auc$best_score - 0.9999969)
    expect_true(auc_gap < TOLERANCE)
})

test_that("lgb.train() works with early stopping for regression", {
set.seed(708L)
trainDF <- data.frame(
Expand Down Expand Up @@ -604,6 +674,67 @@ test_that("lgb.train() works with early stopping for regression", {
)
})

test_that("lgb.train() works with early stopping for regression with a metric that should be minimized", {
    set.seed(708L)
    nrounds <- 10L
    early_stopping_rounds <- 5L

    # Wraps a data frame with "feat1" and "target" columns in an lgb.Dataset
    to_dataset <- function(df) {
        return(
            lgb.Dataset(
                data = as.matrix(df[["feat1"]], drop = FALSE)
                , label = df[["target"]]
            )
        )
    }

    # training data alternates between two (feature, target) pairs, while every
    # validation row uses the same feature value and target
    train_df <- data.frame(
        "feat1" = rep(c(10.0, 100.0), 500L)
        , "target" = rep(c(-50.0, 50.0), 500L)
    )
    valid_df <- data.frame(
        "feat1" = rep(50.0, 4L)
        , "target" = rep(50.0, 4L)
    )
    dtrain <- to_dataset(train_df)
    dvalid <- to_dataset(valid_df)

    # train with early stopping, tracking three metrics at once
    # ("mae" is requested by its alias on purpose)
    training_params <- list(
        objective = "regression"
        , metric = c(
            "mape"
            , "rmse"
            , "mae"
        )
        , min_data_in_bin = 5L
        , early_stopping_rounds = early_stopping_rounds
    )
    bst <- lgb.train(
        params = training_params
        , data = dtrain
        , nrounds = nrounds
        , valids = list(
            "valid1" = dvalid
        )
    )

    # the first iteration should be the best one, and training should have stopped
    # after 6 rounds (1 with improvement, then 5 consecutive rounds without any)
    expect_equal(bst$best_score, 1.1)
    expect_equal(bst$best_iter, 1L)
    mape_history <- bst$record_evals[["valid1"]][["mape"]][["eval"]]
    expect_equal(
        length(mape_history)
        , early_stopping_rounds + 1L
    )

    # Booster should understand that all three of these metrics should be minimized
    booster_private <- bst$.__enclos_env__$private
    expect_identical(booster_private$get_eval_info(), c("mape", "rmse", "l1"))
    expect_identical(
        unname(booster_private$higher_better_inner_eval)
        , rep(FALSE, 3L)
    )
})


test_that("lgb.train() supports non-ASCII feature names", {
testthat::skip("UTF-8 feature names are not fully supported in the R package")
dtrain <- lgb.Dataset(
Expand Down
12 changes: 12 additions & 0 deletions R-package/tests/testthat/test_metrics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
context(".METRICS_HIGHER_BETTER()")

# Sanity checks on the structure of the metric lookup table: it must be a fully named
# logical vector with no repeated metrics and no missing values.
test_that(".METRICS_HIGHER_BETTER() should be well formed", {
    # call the function once so that values and names come from the same object
    metrics <- .METRICS_HIGHER_BETTER()
    metric_names <- names(metrics)
    # should be a logical vector
    expect_true(is.logical(metrics))
    # every entry should carry a non-empty metric name
    expect_false(is.null(metric_names))
    expect_true(all(nzchar(metric_names)))
    # no metrics should be repeated
    expect_identical(anyDuplicated(metric_names), 0L)
    # should not be any NAs
    expect_false(any(is.na(metrics)))
})

0 comments on commit 2a06f0e

Please sign in to comment.