diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index ac9f2404b606..7e00577c7611 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -713,12 +713,15 @@ Booster <- R6::R6Class( #' @param object Object of class \code{lgb.Booster} #' @param newdata a \code{matrix} object, a \code{dgCMatrix} object or #' a character representing a path to a text file (CSV, TSV, or LibSVM) -#' @param start_iteration int or None, optional (default=None) +#' @param start_iteration int or `NULL`, optional (default=`NULL`) #' Start index of the iteration to predict. -#' If None or <= 0, starts from the first iteration. -#' @param num_iteration int or None, optional (default=None) +#' If `NULL` or <= 0, starts from the first iteration. +#' +#' If using `index1=FALSE`, it will be assumed that the numeration starts +#' at zero (e.g. passing '2' will mean starting from the 3rd round). +#' @param num_iteration int or `NULL`, optional (default=`NULL`) #' Limit number of iterations in the prediction. -#' If None, if the best iteration exists and start_iteration is None or <= 0, the +#' If `NULL`, if the best iteration exists and start_iteration is `NULL` or <= 0, the #' best iteration is used; otherwise, all iterations from start_iteration are used. #' If <= 0, all iterations from start_iteration are used (no limits). #' @param rawscore whether the prediction should be returned in the for of original untransformed @@ -731,6 +734,10 @@ Booster <- R6::R6Class( #' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{ #' the "Predict Parameters" section of the documentation} for a list of parameters and #' valid values. +#' @param index1 When passing argument `start_iteration` and/or when producing outputs that correspond +#' to some numeration (such as leaf indices), whether to take these inputs as and/or make +#' these outputs have a numeration starting at 1 or at 0. Note that the underlying lightgbm +#' core library uses zero-based numeration, thus `index1=FALSE` will be slightly faster. #' @param ... ignored #' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}. #' For multiclass classification, it returns a matrix of dimensions \code{(nrows(data), num_class)}. @@ -781,6 +788,7 @@ predict.lgb.Booster <- function(object, predcontrib = FALSE, header = FALSE, params = list(), + index1 = TRUE, ...) { if (!lgb.is.Booster(x = object)) { @@ -799,18 +807,25 @@ predict.lgb.Booster <- function(object, )) } - return( - object$predict( - data = newdata - , start_iteration = start_iteration - , num_iteration = num_iteration - , rawscore = rawscore - , predleaf = predleaf - , predcontrib = predcontrib - , header = header - , params = params - ) + if (!is.null(start_iteration) && start_iteration > 0L && index1) { + start_iteration <- start_iteration - 1L + } + + pred <- object$predict( + data = newdata + , start_iteration = start_iteration + , num_iteration = num_iteration + , rawscore = rawscore + , predleaf = predleaf + , predcontrib = predcontrib + , header = header + , params = params ) + + if (predleaf && index1) { + pred <- pred + 1.0 + } + return(pred) } #' @name print.lgb.Booster diff --git a/R-package/man/predict.lgb.Booster.Rd b/R-package/man/predict.lgb.Booster.Rd index d4ddfe0ff668..94a167f357f7 100644 --- a/R-package/man/predict.lgb.Booster.Rd +++ b/R-package/man/predict.lgb.Booster.Rd @@ -14,6 +14,7 @@ predcontrib = FALSE, header = FALSE, params = list(), + index1 = TRUE, ... ) } @@ -23,13 +24,16 @@ \item{newdata}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a path to a text file (CSV, TSV, or LibSVM)} -\item{start_iteration}{int or None, optional (default=None) -Start index of the iteration to predict. -If None or <= 0, starts from the first iteration.} +\item{start_iteration}{int or `NULL`, optional (default=`NULL`) + Start index of the iteration to predict. + If `NULL` or <= 0, starts from the first iteration. -\item{num_iteration}{int or None, optional (default=None) + If using `index1=FALSE`, it will be assumed that the numeration starts + at zero (e.g. passing '2' will mean starting from the 3rd round).} + +\item{num_iteration}{int or `NULL`, optional (default=`NULL`) Limit number of iterations in the prediction. -If None, if the best iteration exists and start_iteration is None or <= 0, the +If `NULL`, if the best iteration exists and start_iteration is `NULL` or <= 0, the best iteration is used; otherwise, all iterations from start_iteration are used. If <= 0, all iterations from start_iteration are used (no limits).} @@ -48,6 +52,11 @@ for logistic regression would result in predictions for log-odds instead of prob the "Predict Parameters" section of the documentation} for a list of parameters and valid values.} +\item{index1}{When passing argument `start_iteration` and/or when producing outputs that correspond +to some numeration (such as leaf indices), whether to take these inputs as and/or make +these outputs have a numeration starting at 1 or at 0. Note that the underlying lightgbm +core library uses zero-based numeration, thus `index1=FALSE` will be slightly faster.} + \item{...}{ignored} } \value{ diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 5d3f172b9f6e..e1e6d299491d 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -95,6 +95,7 @@ test_that("start_iteration works correctly", { , start_iteration = start_iter , num_iteration = n_iter , rawscore = TRUE + , index1 = FALSE ) inc_pred_contrib <- bst$predict(test$data , start_iteration = start_iter @@ -107,8 +108,20 @@ test_that("start_iteration works correctly", { expect_equal(pred2, pred1) expect_equal(pred_contrib2, pred_contrib1) - pred_leaf1 <- predict(bst, test$data, predleaf = TRUE) - pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE) + pred_leaf1 <- predict( + bst + , test$data + , predleaf = TRUE + , index1 = FALSE + ) + pred_leaf2 <- predict( + bst + , test$data + , start_iteration = 0L + , num_iteration = end_iter + 1L + , predleaf = TRUE + , index1 = FALSE + ) expect_equal(pred_leaf1, pred_leaf2) })