Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Add missing prediction functions to R interface #4982

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export(lgb.Dataset.create.valid)
export(lgb.Dataset.save)
export(lgb.Dataset.set.categorical)
export(lgb.Dataset.set.reference)
export(lgb.configure_fast_predict)
export(lgb.convert_with_rules)
export(lgb.cv)
export(lgb.drop_serialized)
Expand All @@ -38,6 +39,12 @@ export(saveRDS.lgb.Booster)
export(set_field)
export(slice)
import(methods)
importClassesFrom(Matrix,CsparseMatrix)
importClassesFrom(Matrix,RsparseMatrix)
importClassesFrom(Matrix,dgCMatrix)
importClassesFrom(Matrix,dgRMatrix)
importClassesFrom(Matrix,dsparseMatrix)
importClassesFrom(Matrix,dsparseVector)
importFrom(Matrix,Matrix)
importFrom(R6,R6Class)
importFrom(data.table,":=")
Expand All @@ -52,6 +59,7 @@ importFrom(graphics,barplot)
importFrom(graphics,par)
importFrom(jsonlite,fromJSON)
importFrom(methods,is)
importFrom(methods,new)
importFrom(stats,quantile)
importFrom(utils,modifyList)
importFrom(utils,read.delim)
Expand Down
190 changes: 181 additions & 9 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Booster <- R6::R6Class(
params = list(),
record_evals = list(),

fast_predict_config = list(),

# Finalize will free up the handles
finalize = function() {
.Call(
Expand Down Expand Up @@ -491,6 +493,7 @@ Booster <- R6::R6Class(
predictor <- Predictor$new(
modelfile = private$handle
, params = params
, fast_predict_config = self$fast_predict_config
)
return(
predictor$predict(
Expand All @@ -512,6 +515,57 @@ Booster <- R6::R6Class(
return(Predictor$new(modelfile = private$handle))
},

# Initialize a single-row "fast predict" configuration through the native
# interface and cache it, together with the settings that were used, in
# `self$fast_predict_config`.
#
# csr: when FALSE the dense-matrix initializer is used, otherwise the CSR one.
# start_iteration / num_iteration: NULL is replaced by 0L / -1L respectively.
# rawscore, predleaf, predcontrib: forwarded unchanged to the initializer.
# params: list serialized with lgb.params2str() before being passed on.
#
# Returns NULL invisibly; the outcome is the updated `self$fast_predict_config`.
configure_fast_predict = function(csr = FALSE,
                                  start_iteration = NULL,
                                  num_iteration = NULL,
                                  rawscore = FALSE,
                                  predleaf = FALSE,
                                  predcontrib = FALSE,
                                  params = list()) {

  # restore_handle() runs before the booster handle is handed to native code
  self$restore_handle()
  n_features <- .Call(LGBM_BoosterGetNumFeature_R, private$handle)

  # Substitute the defaults for iteration bounds that were left as NULL
  if (is.null(num_iteration)) {
    num_iteration <- -1L
  }
  if (is.null(start_iteration)) {
    start_iteration <- 0L
  }

  # Choose the native initializer that matches the declared input layout
  init_routine <- if (!csr) {
    LGBM_BoosterPredictForMatSingleRowFastInit_R
  } else {
    LGBM_BoosterPredictForCSRSingleRowFastInit_R
  }

  native_config <- .Call(
    init_routine
    , private$handle
    , n_features
    , rawscore
    , predleaf
    , predcontrib
    , start_iteration
    , num_iteration
    , lgb.params2str(params = params)
  )

  # Keep the native handle alongside the settings it was created with;
  # the flag arguments are stored coerced to logical
  self$fast_predict_config <- list(
    handle = native_config
    , csr = as.logical(csr)
    , ncols = n_features
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = as.logical(rawscore)
    , predleaf = as.logical(predleaf)
    , predcontrib = as.logical(predcontrib)
    , params = params
  )

  return(invisible(NULL))
},

# Used for serialization
raw = NULL,

Expand Down Expand Up @@ -709,12 +763,7 @@ Booster <- R6::R6Class(
)
)

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param newdata a \code{matrix} object, a \code{dgCMatrix} object or
#' a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @name lgb_predict_shared_params
#' @param start_iteration int or NULL, optional (default=NULL)
#' Start index of the iteration to predict.
#' If NULL or <= 0, starts from the first iteration.
Expand All @@ -728,13 +777,38 @@ Booster <- R6::R6Class(
#' for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case.
#' @param params a list of additional named parameters. See
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#' the "Predict Parameters" section of the documentation} for a list of parameters and
#' valid values.
NULL

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @details If the model object has been configured for fast single-row CSR predictions through
#' \link{lgb.configure_fast_predict}, this function will use the prediction parameters
#' that were configured for it - as such, extra prediction parameters should not be passed
#' here, otherwise the configuration will be ignored and the slow route will be taken.
#' @inheritParams lgb_predict_shared_params
#' @param object Object of class \code{lgb.Booster}
#' @param newdata a \code{matrix} object, a \code{dgCMatrix}, a \code{dgRMatrix} object, a \code{dsparseVector} object,
#' or a character representing a path to a text file (CSV, TSV, or LibSVM).
#'
#' For sparse inputs, if predictions are only going to be made for a single row, it will be faster to
#' use CSR format, in which case the data may be passed as either a single-row CSR matrix (class
#' `dgRMatrix` from package `Matrix`) or as a sparse numeric vector (class `dsparseVector` from
#' package `Matrix`).
#'
#' If single-row predictions are going to be performed frequently, it is recommended to
#' pre-configure the model object for fast single-row sparse predictions through function
#' \link{lgb.configure_fast_predict}.
#' @param header only used for prediction from a text file. \code{TRUE} if the text file has a header.
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case.
#'
#' If passing `predcontrib=TRUE` and the input data is sparse, this parameter will be forced
#' to `TRUE`, outputting a sparse matrix or vector of the same class as the input data.
#' @param ... ignored
#' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#' For multiclass classification, either a \code{num_class * nrows(data)} vector or
Expand All @@ -744,6 +818,9 @@ Booster <- R6::R6Class(
#' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees.
#'
#' If using `predcontrib=TRUE` and the input data is a sparse matrix or sparse vector,
#' the output will also be a sparse matrix or vector of the same class.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
Expand Down Expand Up @@ -803,6 +880,10 @@ predict.lgb.Booster <- function(object,
))
}

if (!reshape && predcontrib && inherits(newdata, c("dsparseMatrix", "dsparseVector"))) {
reshape <- TRUE
}

return(
object$predict(
data = newdata
Expand All @@ -818,6 +899,97 @@ predict.lgb.Booster <- function(object,
)
}

#' @title Configure Fast Single-Row Predictions
#' @description Pre-configures a LightGBM model object to produce fast single-row predictions
#' for a given input data type, prediction type, and parameters.
#' @details Calling this function multiple times with different parameters might not override
#' the previous configuration and might trigger undefined behavior.
#'
#' Any saved configuration for fast predictions might be lost after making a single-row
#' prediction of a different type than what was configured.
#'
#' In some situations, setting a fast prediction configuration for one type of prediction
#' might cause the prediction function to keep using that configuration for single-row
#' predictions even if the requested type of prediction is different from what was configured.
#'
#' The configuration does not survive de-serializations, so it has to be generated
#' anew in every R process that is going to use it (e.g. if loading a model object
#' through `readRDS`, whatever configuration was there previously will be lost).
#'
#' Requesting a different prediction type or passing parameters to \link{predict.lgb.Booster}
#' will cause it to ignore the fast-predict configuration and take the slow route instead
#' (but be aware that an existing configuration might not always be overridden by supplying
#' different parameters or prediction type, so make sure to check that the output is what
#' was expected when a prediction is to be made on a single row for something different than
#' what is configured).
#'
#' Note that, if configuring a non-default prediction type (such as leaf indices),
#' then that type must also be passed in the call to \link{predict.lgb.Booster} in
#' order for it to use the configuration. This also applies for `start_iteration`
#' and `num_iteration`, but \bold{the `params` list must be empty} in the call to `predict`.
#'
#' Predictions about feature contributions do not allow a fast route for CSR inputs,
#' and as such, this function will produce an error if passing `csr=TRUE` and
#' `predcontrib=TRUE` together.
#' @inheritParams lgb_predict_shared_params
#' @param model LightGBM model object (class \code{lgb.Booster}).
#'
#' \bold{The object will be modified in-place}.
#' @param csr Whether the prediction function is going to be called on sparse CSR inputs.
#' If `FALSE`, it will be assumed that predictions are going to be made on single-row
#' regular R matrices.
#' @return The same `model` that was passed as input, as invisible, with the desired
#' configuration stored inside it and available to be used in future calls to
#' \link{predict.lgb.Booster}.
#' @examples
#' library(lightgbm)
#' data(mtcars)
#' X <- as.matrix(mtcars[, -1L])
#' y <- mtcars[, 1L]
#' dtrain <- lgb.Dataset(X, label = y, params = list(max_bin = 5L))
#' params <- list(min_data_in_leaf = 2L)
#' model <- lgb.train(
#' params = params
#' , data = dtrain
#' , obj = "regression"
#' , nrounds = 5L
#' , verbose = -1L
#' )
#' lgb.configure_fast_predict(model)
#'
#' x_single <- X[11L, , drop = FALSE]
#' predict(model, x_single)
#'
#' # Will not use it if the prediction to be made
#' # is different from what was configured
#' predict(model, x_single, predleaf = TRUE)
#' @export
lgb.configure_fast_predict <- function(model,
                                       csr = FALSE,
                                       start_iteration = NULL,
                                       num_iteration = NULL,
                                       rawscore = FALSE,
                                       predleaf = FALSE,
                                       predcontrib = FALSE,
                                       params = list()) {
  # Guard: only booster objects carry the configure_fast_predict() method
  is_booster <- lgb.is.Booster(x = model)
  if (!is_booster) {
    stop("lgb.configure_fast_predict: model should be an ", sQuote("lgb.Booster"))
  }

  # Guard: feature contributions cannot be combined with the CSR fast path
  if (csr && predcontrib) {
    stop("'lgb.configure_fast_predict' does not support feature contributions for CSR data.")
  }

  # Delegate to the booster method, which stores the configuration on `model`
  model$configure_fast_predict(
    csr = csr
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , params = params
  )

  # Hand the same object back invisibly
  return(invisible(model))
}

#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
Expand Down
Loading