Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R] Use inplace predict #9829

Merged
merged 29 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2471933
use inplace predict
david-cortes Nov 30, 2023
b605ef7
simplify further
david-cortes Nov 30, 2023
af81946
solve merge conflicts
david-cortes Dec 2, 2023
5b8c661
solve merge conflicts
david-cortes Dec 6, 2023
080f372
enable inplace_predict also for non-training dart
david-cortes Dec 11, 2023
d1fdac0
add base_margin to predict
david-cortes Dec 11, 2023
bd62c5b
solve merge conflicts
david-cortes Dec 12, 2023
9a19e3d
add test for base_margin
david-cortes Dec 12, 2023
d07162b
add inplace prediction also for data frames
david-cortes Dec 12, 2023
dea90b0
corrections for prediction on data.frame
david-cortes Dec 13, 2023
9347a41
linter
david-cortes Dec 17, 2023
248204c
solve merge conflicts
david-cortes Dec 27, 2023
f6cd2db
remove unused variable
david-cortes Dec 27, 2023
d824ff4
spacing
david-cortes Dec 27, 2023
50811c3
Merge remote-tracking branch 'upstream/master' into inplace_predict
hcho3 Jan 2, 2024
7a775d0
solve merge conflict
david-cortes Jan 8, 2024
2b32b4f
solve conflicts
david-cortes Jan 11, 2024
8584093
update roxygen
david-cortes Jan 11, 2024
5fec492
solve conflicts
david-cortes Jan 20, 2024
5b045ba
solve merge conflicts
david-cortes Jan 30, 2024
4f07239
use anonymous namespace instead of statics
david-cortes Jan 30, 2024
218c3db
linter
david-cortes Jan 30, 2024
b999c2a
solve merge conflicts
david-cortes Jan 31, 2024
afde244
Merge branch 'master' into inplace_predict
david-cortes Feb 7, 2024
e03527c
add inplace predict also for sparse vectors
david-cortes Feb 7, 2024
16cb633
update docs
david-cortes Feb 7, 2024
97e7e76
Merge branch 'master' into inplace_predict
david-cortes Feb 22, 2024
b7ca893
update wording
david-cortes Feb 22, 2024
9609b61
move warning to error
david-cortes Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 115 additions & 18 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,26 +77,44 @@ xgb.get.handle <- function(object) {

#' Predict method for XGBoost model
#'
#' Predicted values based on either xgboost model or model handle object.
#' Predict values on data based on xgboost model.
#'
#' @param object Object of class `xgb.Booster`.
#' @param newdata Takes `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
#' @param newdata Takes `data.frame`, `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
#' local data file, or `xgb.DMatrix`.
#' For single-row predictions on sparse data, it is recommended to use the CSR format.
#' If passing a sparse vector, it will take it as a row vector.
#' @param missing Only used when input is a dense matrix. Pick a float value that represents
#' missing values in data (e.g., 0 or some other extreme value).
#'
#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing
#' a sparse vector, it will take it as a row vector.
#'
#' Note that, for repeated predictions, one might want to create a DMatrix to pass here instead
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
#' of passing R types like matrices or data frames, as predictions will be faster on DMatrix.
#'
#' If `newdata` is a `data.frame`, be aware that:\itemize{
#' \item Columns will be converted to numeric if they aren't already, which could potentially make
#' the operation slower than in an equivalent `matrix` object.
#' \item The order of the columns must match with that of the data from which the model was fitted
#' (i.e. columns will not be referenced by their names, just by their order in the data).
#' \item If the model was fitted to data with categorical columns, these columns must be of
#' `factor` type here, and must use the same encoding (i.e. have the same levels).
#' \item If `newdata` contains any `factor` columns, they will be converted to base-0
#' encoding (same as during DMatrix creation) - hence, one should not pass a `factor`
#' under a column which during training had a different type.
#' }
#' @param missing Float value that represents missing values in data (e.g., 0 or some other extreme value).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually use nan as missing value, since XGBoost can handle 0 just as any other real number.

This is also a headache for us since many sparse matrix implementations assume 0 as missing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, default value is R's NA, which gets translated to either NAN or R_NaInt depending on the input type.

#'
#' This parameter is not used when `newdata` is an `xgb.DMatrix` - in such cases, should pass
#' this as an argument to the DMatrix constructor instead.
#' @param outputmargin Whether the prediction should be returned in the form of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting `outputmargin=TRUE` for
#' logistic regression would return log-odds instead of probabilities.
#' @param predleaf Whether to predict pre-tree leaf indices.
#' @param predleaf Whether to predict per-tree leaf indices.
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
#' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
#' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
#' or `predinteraction` is `TRUE`.
#' @param training Whether the predictions are used for training. For dart booster,
#' @param training Whether the prediction result is used for training. For dart booster,
#' training predicting will perform dropout.
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
#' a two-dimensional vector with the start and end numbers in the sequence (same format as R's `seq` - i.e.
Expand All @@ -111,6 +129,12 @@ xgb.get.handle <- function(object) {
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
#' type and shape of predictions are invariant to the model type.
#' @param base_margin Base margin used for boosting from existing model.
#'
#' Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
#' be ignored as it needs to be added to the DMatrix instead (e.g. by passing it as
#' an argument in its constructor, or by calling \link{setinfo.xgb.DMatrix}).
#'
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
#' match (only applicable when both `object` and `newdata` have feature names).
#'
Expand Down Expand Up @@ -287,16 +311,77 @@ xgb.get.handle <- function(object) {
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
validate_features = FALSE, ...) {
validate_features = FALSE, base_margin = NULL, ...) {
if (validate_features) {
newdata <- validate.features(object, newdata)
}
if (!inherits(newdata, "xgb.DMatrix")) {
is_dmatrix <- inherits(newdata, "xgb.DMatrix")
if (is_dmatrix && !is.null(base_margin)) {
warning("'base_margin' is ignored when passing 'xgb.DMatrix' as input.")
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
}

use_as_df <- FALSE
use_as_dense_matrix <- FALSE
use_as_csr_matrix <- FALSE
n_row <- NULL
if (!is_dmatrix) {

inplace_predict_supported <- !predcontrib && !predinteraction && !predleaf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to support leaf in the near future.

if (inplace_predict_supported) {
booster_type <- xgb.booster_type(object)
if (booster_type == "gblinear" || (booster_type == "dart" && training)) {
inplace_predict_supported <- FALSE
}
}
if (inplace_predict_supported) {

if (is.matrix(newdata)) {
use_as_dense_matrix <- TRUE
} else if (is.data.frame(newdata)) {
# note: since here it turns it into a non-data-frame list,
# needs to keep track of the number of rows it had for later
n_row <- nrow(newdata)
newdata <- lapply(
newdata,
function(x) {
if (is.factor(x)) {
return(as.numeric(x) - 1)
} else {
return(as.numeric(x))
}
}
)
use_as_df <- TRUE
} else if (inherits(newdata, "dgRMatrix")) {
use_as_csr_matrix <- TRUE
csr_data <- list(newdata@p, newdata@j, newdata@x, ncol(newdata))
} else if (inherits(newdata, "dsparseVector")) {
use_as_csr_matrix <- TRUE
n_row <- 1L
i <- newdata@i - 1L
if (storage.mode(i) != "integer") {
storage.mode(i) <- "integer"
}
csr_data <- list(c(0L, length(i)), i, newdata@x, length(newdata))
}

}

} # if (!is_dmatrix)

if (!is_dmatrix && !use_as_dense_matrix && !use_as_csr_matrix && !use_as_df) {
nthread <- xgb.nthread(object)
newdata <- xgb.DMatrix(
newdata,
missing = missing, nthread = NVL(nthread, -1)
missing = missing,
base_margin = base_margin,
nthread = NVL(nthread, -1)
)
is_dmatrix <- TRUE
}

if (is.null(n_row)) {
n_row <- nrow(newdata)
}


Expand Down Expand Up @@ -354,18 +439,30 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
args$type <- set_type(6)
}

predts <- .Call(
XGBoosterPredictFromDMatrix_R,
xgb.get.handle(object),
newdata,
jsonlite::toJSON(args, auto_unbox = TRUE)
)
json_conf <- jsonlite::toJSON(args, auto_unbox = TRUE)
if (is_dmatrix) {
predts <- .Call(
XGBoosterPredictFromDMatrix_R, xgb.get.handle(object), newdata, json_conf
)
} else if (use_as_dense_matrix) {
predts <- .Call(
XGBoosterPredictFromDense_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
)
} else if (use_as_csr_matrix) {
predts <- .Call(
XGBoosterPredictFromCSR_R, xgb.get.handle(object), csr_data, missing, json_conf, base_margin
)
} else if (use_as_df) {
predts <- .Call(
XGBoosterPredictFromColumnar_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
)
}

names(predts) <- c("shape", "results")
shape <- predts$shape
arr <- predts$results

n_ret <- length(arr)
n_row <- nrow(newdata)
if (n_row != shape[1]) {
stop("Incorrect predict shape.")
}
Expand Down
44 changes: 36 additions & 8 deletions R-package/man/predict.xgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions R-package/src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromDense_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromColumnar_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
Expand Down Expand Up @@ -96,6 +99,9 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
{"XGBoosterPredictFromDense_R", (DL_FUNC) &XGBoosterPredictFromDense_R, 5},
{"XGBoosterPredictFromCSR_R", (DL_FUNC) &XGBoosterPredictFromCSR_R, 5},
{"XGBoosterPredictFromColumnar_R", (DL_FUNC) &XGBoosterPredictFromColumnar_R, 5},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
{"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
Expand Down
Loading
Loading