Skip to content

Commit

Permalink
[R-package] ensure values in params override keyword arguments to pre…
Browse files Browse the repository at this point in the history
…dict() (fixes #4670) (#5122)

* fix predict() params

* [R-package] ensure values in params override keyword arguments to predict() (fixes #4670)

* revert accidentally-introduced unrelated test changes

* Update R-package/tests/testthat/test_Predictor.R

Co-authored-by: José Morales <[email protected]>

* Update R-package/tests/testthat/test_Predictor.R

Co-authored-by: José Morales <[email protected]>

* linting

* remove nammes in shap test

* changes to tests

Co-authored-by: José Morales <[email protected]>
  • Loading branch information
jameslamb and jmoralez authored Apr 24, 2022
1 parent 3d25e37 commit 21fb16a
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 2 deletions.
31 changes: 30 additions & 1 deletion R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,34 @@ Booster <- R6::R6Class(
start_iteration <- 0L
}

# possibly override keyword arguments with parameters
#
# NOTE: this length() check minimizes the latency introduced by these checks,
# for the common case where params is empty
#
# NOTE: doing this here instead of in Predictor$predict() to keep
# Predictor$predict() as fast as possible
if (length(params) > 0L) {
params <- lgb.check.wrapper_param(
main_param_name = "predict_raw_score"
, params = params
, alternative_kwarg_value = rawscore
)
params <- lgb.check.wrapper_param(
main_param_name = "predict_leaf_index"
, params = params
, alternative_kwarg_value = predleaf
)
params <- lgb.check.wrapper_param(
main_param_name = "predict_contrib"
, params = params
, alternative_kwarg_value = predcontrib
)
rawscore <- params[["predict_raw_score"]]
predleaf <- params[["predict_leaf_index"]]
predcontrib <- params[["predict_contrib"]]
}

# Predict on new data
predictor <- Predictor$new(
modelfile = private$handle
Expand Down Expand Up @@ -730,7 +758,8 @@ Booster <- R6::R6Class(
#' @param params a list of additional named parameters. See
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#' the "Predict Parameters" section of the documentation} for a list of parameters and
#' valid values.
#' valid values. Where these conflict with the values of keyword arguments to this function,
#' the values in \code{params} take precedence.
#' @param ... ignored
#' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#' For multiclass classification, it returns a matrix of dimensions \code{(nrows(data), num_class)}.
Expand Down
3 changes: 2 additions & 1 deletion R-package/man/predict.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

159 changes: 159 additions & 0 deletions R-package/tests/testthat/test_Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ VERBOSITY <- as.integer(
Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
)

TOLERANCE <- 1e-6

library(Matrix)

test_that("Predictor$finalize() should not fail", {
Expand Down Expand Up @@ -114,6 +116,163 @@ test_that("start_iteration works correctly", {
expect_equal(pred_leaf1, pred_leaf2)
})

test_that("predict() params should override keyword argument for raw-score predictions", {
data(agaricus.train, package = "lightgbm")
X <- agaricus.train$data
y <- agaricus.train$label
bst <- lgb.train(
data = lgb.Dataset(
data = X
, label = y
, params = list(
data_seed = 708L
, min_data_in_bin = 5L
)
)
, params = list(
objective = "binary"
, min_data_in_leaf = 1L
, seed = 708L
)
, nrounds = 10L
, verbose = VERBOSITY
)

# check that the predictions from predict.lgb.Booster() really look like raw score predictions
preds_prob <- predict(bst, X)
preds_raw_s3_keyword <- predict(bst, X, rawscore = TRUE)
preds_prob_from_raw <- 1.0 / (1.0 + exp(-preds_raw_s3_keyword))
expect_equal(preds_prob, preds_prob_from_raw, tolerance = TOLERANCE)
accuracy <- sum(as.integer(preds_prob_from_raw > 0.5) == y) / length(y)
expect_equal(accuracy, 1.0)

# should get the same results from Booster$predict() method
preds_raw_r6_keyword <- bst$predict(X, rawscore = TRUE)
expect_equal(preds_raw_s3_keyword, preds_raw_r6_keyword)

# using a parameter alias of predict_raw_score should result in raw scores being returned
aliases <- .PARAMETER_ALIASES()[["predict_raw_score"]]
expect_true(length(aliases) > 1L)
for (rawscore_alias in aliases) {
params <- as.list(
stats::setNames(
object = TRUE
, nm = rawscore_alias
)
)
preds_raw_s3_param <- predict(bst, X, params = params)
preds_raw_r6_param <- bst$predict(X, params = params)
expect_equal(preds_raw_s3_keyword, preds_raw_s3_param)
expect_equal(preds_raw_s3_keyword, preds_raw_r6_param)
}
})

test_that("predict() params should override keyword argument for leaf-index predictions", {
data(mtcars)
X <- as.matrix(mtcars[, which(names(mtcars) != "mpg")])
y <- as.numeric(mtcars[, "mpg"])
bst <- lgb.train(
data = lgb.Dataset(
data = X
, label = y
, params = list(
min_data_in_bin = 1L
, data_seed = 708L
)
)
, params = list(
objective = "regression"
, min_data_in_leaf = 1L
, seed = 708L
)
, nrounds = 10L
, verbose = VERBOSITY
)

# check that predictions really look like leaf index predictions
preds_leaf_s3_keyword <- predict(bst, X, predleaf = TRUE)
expect_true(is.matrix(preds_leaf_s3_keyword))
expect_equal(dim(preds_leaf_s3_keyword), c(nrow(X), bst$current_iter()))
expect_true(min(preds_leaf_s3_keyword) >= 0L)
trees_dt <- lgb.model.dt.tree(bst)
max_leaf_by_tree_from_dt <- trees_dt[, .(idx = max(leaf_index, na.rm = TRUE)), by = tree_index]$idx
max_leaf_by_tree_from_preds <- apply(preds_leaf_s3_keyword, 2L, max, na.rm = TRUE)
expect_equal(max_leaf_by_tree_from_dt, max_leaf_by_tree_from_preds)

# should get the same results from Booster$predict() method
preds_leaf_r6_keyword <- bst$predict(X, predleaf = TRUE)
expect_equal(preds_leaf_s3_keyword, preds_leaf_r6_keyword)

# using a parameter alias of predict_leaf_index should result in leaf indices being returned
aliases <- .PARAMETER_ALIASES()[["predict_leaf_index"]]
expect_true(length(aliases) > 1L)
for (predleaf_alias in aliases) {
params <- as.list(
stats::setNames(
object = TRUE
, nm = predleaf_alias
)
)
preds_leaf_s3_param <- predict(bst, X, params = params)
preds_leaf_r6_param <- bst$predict(X, params = params)
expect_equal(preds_leaf_s3_keyword, preds_leaf_s3_param)
expect_equal(preds_leaf_s3_keyword, preds_leaf_r6_param)
}
})

test_that("predict() params should override keyword argument for feature contributions", {
data(mtcars)
X <- as.matrix(mtcars[, which(names(mtcars) != "mpg")])
y <- as.numeric(mtcars[, "mpg"])
bst <- lgb.train(
data = lgb.Dataset(
data = X
, label = y
, params = list(
min_data_in_bin = 1L
, data_seed = 708L
)
)
, params = list(
objective = "regression"
, min_data_in_leaf = 1L
, seed = 708L
)
, nrounds = 10L
, verbose = VERBOSITY
)

# check that predictions really look like feature contributions
preds_contrib_s3_keyword <- predict(bst, X, predcontrib = TRUE)
num_features <- ncol(X)
shap_base_value <- unname(preds_contrib_s3_keyword[, ncol(preds_contrib_s3_keyword)])
expect_true(is.matrix(preds_contrib_s3_keyword))
expect_equal(dim(preds_contrib_s3_keyword), c(nrow(X), num_features + 1L))
expect_equal(length(unique(shap_base_value)), 1L)
expect_equal(mean(y), shap_base_value[1L])
expect_equal(predict(bst, X), rowSums(preds_contrib_s3_keyword))

# should get the same results from Booster$predict() method
preds_contrib_r6_keyword <- bst$predict(X, predcontrib = TRUE)
expect_equal(preds_contrib_s3_keyword, preds_contrib_r6_keyword)

# using a parameter alias of predict_contrib should result in feature contributions being returned
aliases <- .PARAMETER_ALIASES()[["predict_contrib"]]
expect_true(length(aliases) > 1L)
for (predcontrib_alias in aliases) {
params <- as.list(
stats::setNames(
object = TRUE
, nm = predcontrib_alias
)
)
preds_contrib_s3_param <- predict(bst, X, params = params)
preds_contrib_r6_param <- bst$predict(X, params = params)
expect_equal(preds_contrib_s3_keyword, preds_contrib_s3_param)
expect_equal(preds_contrib_s3_keyword, preds_contrib_r6_param)
}
})

.expect_has_row_names <- function(pred, X) {
if (is.vector(pred)) {
rnames <- names(pred)
Expand Down

0 comments on commit 21fb16a

Please sign in to comment.