Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Add sparse feature contribution predictions #5108

Merged
merged 23 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
508fb1b
add predcontrib for sparse inputs
david-cortes Mar 30, 2022
557a005
register newly-added function
david-cortes Mar 31, 2022
a740dc8
solve merge conflicts
david-cortes Apr 3, 2022
631c935
comments
david-cortes Apr 3, 2022
82900f6
correct wrong types in test
david-cortes Apr 3, 2022
6fd0868
forcibly take transpose function from Matrix
david-cortes Apr 3, 2022
398ddb9
keep row names, test comparison to dense inputs
david-cortes Apr 4, 2022
7d53353
workaround for passing test while PR for row names is not merged
david-cortes Apr 4, 2022
19b706f
solve merge conflict
david-cortes Apr 5, 2022
8b016b8
Update R-package/R/lgb.Predictor.R
david-cortes Apr 10, 2022
df49aaa
Update R-package/R/lgb.Predictor.R
david-cortes Apr 10, 2022
5df97de
Update R-package/R/lgb.Predictor.R
david-cortes Apr 10, 2022
6eeddc4
proper handling of integer overflow
david-cortes Apr 10, 2022
e04c6f4
add test for CSR contrib row names
david-cortes Apr 14, 2022
4ac5a3b
solve merge conflicts
david-cortes May 14, 2022
e38a6d7
add more tests for predict(<sparse>, predcontrib=TRUE)
david-cortes May 14, 2022
5b65fd1
make linter happy
david-cortes May 14, 2022
c00a8ff
linter
david-cortes May 14, 2022
d905910
linter
david-cortes May 14, 2022
edaac4e
check error messages for bad input shapes
david-cortes May 26, 2022
ab86ad3
fix regex
david-cortes May 26, 2022
005707d
Merge github.com:microsoft/lightgbm into Rcsr1
david-cortes Jun 16, 2022
3f8467a
hard-coded number of columns in regex for tests
david-cortes Jun 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ export(saveRDS.lgb.Booster)
export(set_field)
export(slice)
import(methods)
importClassesFrom(Matrix,dgCMatrix)
importClassesFrom(Matrix,dgRMatrix)
importClassesFrom(Matrix,dsparseMatrix)
importClassesFrom(Matrix,dsparseVector)
importFrom(Matrix,Matrix)
importFrom(R6,R6Class)
importFrom(data.table,":=")
Expand All @@ -51,6 +55,7 @@ importFrom(graphics,barplot)
importFrom(graphics,par)
importFrom(jsonlite,fromJSON)
importFrom(methods,is)
importFrom(methods,new)
importFrom(parallel,detectCores)
importFrom(stats,quantile)
importFrom(utils,modifyList)
Expand Down
110 changes: 109 additions & 1 deletion R-package/R/lgb.Predictor.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#' @importFrom methods is
#' @importFrom methods is new
#' @importClassesFrom Matrix dsparseMatrix dsparseVector dgCMatrix dgRMatrix
#' @importFrom R6 R6Class
#' @importFrom utils read.delim
Predictor <- R6::R6Class(
Expand Down Expand Up @@ -126,6 +127,113 @@ Predictor <- R6::R6Class(
num_row <- nrow(preds)
preds <- as.vector(t(preds))

} else if (predcontrib && inherits(data, c("dsparseMatrix", "dsparseVector"))) {

ncols <- .Call(LGBM_BoosterGetNumFeature_R, private$handle)
ncols_out <- integer(1L)
.Call(LGBM_BoosterGetNumClasses_R, private$handle, ncols_out)
ncols_out <- (ncols + 1L) * max(ncols_out, 1L)
if (is.na(ncols_out)) {
ncols_out <- as.numeric(ncols + 1L) * as.numeric(max(ncols_out, 1L))
}
if (!inherits(data, "dsparseVector") && ncols_out > .Machine$integer.max) {
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
stop("Resulting matrix of feature contributions is too large for R to handle.")
}

if (inherits(data, "dsparseVector")) {

if (length(data) > ncols) {
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
, ncols
, length(data)))
}
res <- .Call(
LGBM_BoosterPredictSparseOutput_R
, private$handle
, c(0L, as.integer(length(data@x)))
, data@i - 1L
, data@x
, TRUE
, 1L
, ncols
, start_iteration
, num_iteration
, private$params
)
out <- methods::new("dsparseVector")
out@i <- res$indices + 1L
out@x <- res$data
out@length <- ncols_out
return(out)

} else if (inherits(data, "dgRMatrix")) {

if (ncol(data) > ncols) {
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
, ncols
, ncol(data)))
}
res <- .Call(
LGBM_BoosterPredictSparseOutput_R
, private$handle
, data@p
, data@j
, data@x
, TRUE
, nrow(data)
, ncols
, start_iteration
, num_iteration
, private$params
)
out <- methods::new("dgRMatrix")
out@p <- res$indptr
out@j <- res$indices
out@x <- res$data
out@Dim <- as.integer(c(nrow(data), ncols_out))

} else if (inherits(data, "dgCMatrix")) {

if (ncol(data) != ncols) {
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
, ncols
, ncol(data)))
}
res <- .Call(
LGBM_BoosterPredictSparseOutput_R
, private$handle
, data@p
, data@i
, data@x
, FALSE
, nrow(data)
, ncols
, start_iteration
, num_iteration
, private$params
)
out <- methods::new("dgCMatrix")
out@p <- res$indptr
out@i <- res$indices
out@x <- res$data
out@Dim <- as.integer(c(nrow(data), length(res$indptr) - 1L))

} else {

stop(sprintf("Predictions on sparse inputs are only allowed for '%s', '%s', '%s' - got: %s"
, "dsparseVector"
, "dgRMatrix"
, "dgCMatrix"
, paste(class(data)
, collapse = ", ")))

}

if (NROW(row.names(data))) {
out@Dimnames[[1L]] <- row.names(data)
}
Comment on lines +232 to +234
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please test this behavior, by adding these new combinations of predcontrib = TRUE + new {Matrix} classes to the tests from #4977?

# sparse matrix with row names

# sparse matrix without row names

Every PR adding new behavior to the package should include tests on that behavior, to catch unexpected issues with the implementation and to prevent future development from accidentally breaking that behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those are already tested in the tests from the PR for row names.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are not. The links in my comment point to tests on "CsparseMatrix" objects, but if you click them you won't see tests on the types referenced in this PR: "dsparseMatrix", "dsparseVector", "dgRMatrix", "dgCMatrix".

Is there something I've misunderstood about the relationship between these classes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are not. The links in my comment point to tests on "CsparseMatrix" objects, but if you click them you won't see tests on the types referenced in this PR: "dsparseMatrix", "dsparseVector", "dgRMatrix", "dgCMatrix".

Is there something I've misunderstood about the relationship between these classes?

There's a class hierarchy...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please tell us specifically what you mean by "there's a class hierarchy", and why it means that you don't want to add the tests I'm asking you to add.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dgCMatrix is a subclass of CsparseMatrix, which is a subclass of sparseMatrix, and so on. Classes like dsparseMatrix are abstract.

Copy link
Collaborator

@jameslamb jameslamb Apr 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for that. I'm still struggling to understand how that means that the tests I'm asking for shouldn't be added.

Consider the case added in this PR beginning with } else if (inherits(data, "dgRMatrix")) {. It doesn't contain a return() statement, so that at the end of the if - else if block, the "possibly add row names" code (the line this comment thread is on) will run.

If someone were to add return(out) on line 193, that "possibly add row names" code wouldn't be reached. I think it's desirable for a test to fail in that case, to inform us that adding that return statement is a breaking change that causes row names to not be set on the predictions.

The tests I linked test a regular dense R matrix and a CsparseMatrix. As the example below shows, that means they don't currently cover the cases where the input to predict() is a "dgRMatrix" or a "dsparseVector".

library(Matrix)

# the first batch of test cases use a regular R dense matrix
X <- matrix(rnorm(100), ncol = 4)
inherits(X, "dgRMatrix")       # FALSE
inherits(X, "dsparseMatrix")  # FALSE
inherits(X, "dsparseVector")  # FALSE
inherits(X, "dgCMatrix")       # FALSE

# the second batch of test cases converts that to a CsparseMatrix
Xcsc <- as(X, "CsparseMatrix")
inherits(Xcsc, "dgRMatrix")       # FALSE
inherits(Xcsc, "dsparseMatrix")  # TRUE
inherits(Xcsc, "dsparseVector")  # FALSE
inherits(Xcsc, "dgCMatrix")       # TRUE

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, if you really want to have tests for everything, I've added a test for CSR matrices. A vector representing a single row cannot have row names so I left that out of tests.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you really want to have tests for everything

Thank you, yes I do. The more of this project's behaviors are reflected in tests, the less likely it is that future changes silently break that behavior. This project is too large for all of these concerns to just be kept in maintainers' heads and enforced through PR comments.

return(out)

} else {

# Not a file, we need to predict from R object
Expand Down
87 changes: 87 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ SEXP wrapped_R_raw(void *len) {
return Rf_allocVector(RAWSXP, *(reinterpret_cast<R_xlen_t*>(len)));
}

/*!
* R_UnwindProtect allocation callback: allocates an R integer vector (INTSXP).
* \param len pointer to an R_xlen_t holding the desired length; passed as
*        void* to satisfy the R_UnwindProtect callback signature
* \return freshly allocated, uninitialized INTSXP of that length
*/
SEXP wrapped_R_int(void *len) {
  const R_xlen_t n = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(INTSXP, n);
}

/*!
* R_UnwindProtect allocation callback: allocates an R numeric vector (REALSXP).
* \param len pointer to an R_xlen_t holding the desired length; passed as
*        void* to satisfy the R_UnwindProtect callback signature
* \return freshly allocated, uninitialized REALSXP of that length
*/
SEXP wrapped_R_real(void *len) {
  const R_xlen_t n = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(REALSXP, n);
}

SEXP wrapped_Rf_mkChar(void *txt) {
return Rf_mkChar(reinterpret_cast<char*>(txt));
}
Expand All @@ -84,6 +92,14 @@ SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
return R_UnwindProtect(wrapped_R_raw, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
}

/*!
* Allocate an R integer vector of length `len`, routing any allocation
* failure through throw_R_memerr via R_UnwindProtect so the unwind is
* safe to perform from C++ code.
*/
SEXP safe_R_int(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_int, len_arg, throw_R_memerr, cont_token, *cont_token);
}

/*!
* Allocate an R numeric vector of length `len`, routing any allocation
* failure through throw_R_memerr via R_UnwindProtect so the unwind is
* safe to perform from C++ code.
*/
SEXP safe_R_real(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_real, len_arg, throw_R_memerr, cont_token, *cont_token);
}

SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
return R_UnwindProtect(wrapped_Rf_mkChar, reinterpret_cast<void*>(txt), throw_R_memerr, cont_token, *cont_token);
}
Expand Down Expand Up @@ -851,6 +867,76 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
R_API_END();
}

/*!
* Owning holder for the three buffers allocated by
* LGBM_BoosterPredictSparseOutput (index pointer, indices, values).
* Paired with delete_SparseOutputPointers through std::unique_ptr so the
* C-API buffers are released even if a later R allocation unwinds.
*
* Fix: removed the `indptr_type` / `data_type` members — they were never
* initialized by the constructor and never read, so they only held
* indeterminate values.
*/
struct SparseOutputPointers {
  void* indptr;      // CSR/CSC index-pointer buffer owned by the C API
  int32_t* indices;  // non-zero index buffer owned by the C API
  void* data;        // non-zero value buffer owned by the C API
  SparseOutputPointers(void* indptr, int32_t* indices, void* data)
    : indptr(indptr), indices(indices), data(data) {}
};

/*!
* unique_ptr deleter for SparseOutputPointers: hands the C-API-allocated
* prediction buffers back to LightGBM, then frees the holder itself.
* The dtype constants match the types requested by the prediction call
* (int32 indptr, float64 values).
*/
void delete_SparseOutputPointers(SparseOutputPointers *ptr) {
  LGBM_BoosterFreePredictSparse(
    ptr->indptr,
    ptr->indices,
    ptr->data,
    C_API_DTYPE_INT32,
    C_API_DTYPE_FLOAT64);
  delete ptr;
}

/*!
* Produce feature-contribution predictions for sparse (CSR/CSC) input,
* returning the result as an R list of sparse components in the same
* storage order as the input. See the declaration in lightgbm_R.h for
* parameter documentation.
*/
SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP is_csr,
  SEXP nrows,
  SEXP ncols,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  // Continuation token used by the safe_R_* allocators for unwind-protected allocation
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // Result is an R list with entries "indptr", "indices", "data"
  const char* out_names[] = {"indptr", "indices", "data", ""};
  SEXP out = PROTECT(Rf_mkNamed(VECSXP, out_names));
  const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));

  // Filled by the C API: out_len[0] = number of stored elements,
  // out_len[1] = length of the output index-pointer array
  int64_t out_len[2];
  void *out_indptr;
  int32_t *out_indices;
  void *out_data;

  // This entry point only produces feature contributions (C_API_PREDICT_CONTRIB).
  // The "elements per cross-section" argument is ncols for CSR input, nrows for CSC.
  CHECK_CALL(LGBM_BoosterPredictSparseOutput(R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr), Rf_xlength(data),
    Rf_asLogical(is_csr)? Rf_asInteger(ncols) : Rf_asInteger(nrows),
    C_API_PREDICT_CONTRIB, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    parameter_ptr,
    Rf_asLogical(is_csr)? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC,
    out_len, &out_indptr, &out_indices, &out_data));

  // RAII guard: the C-API buffers are freed via delete_SparseOutputPointers
  // even if one of the R allocations below triggers an unwind
  std::unique_ptr<SparseOutputPointers, decltype(&delete_SparseOutputPointers)> pointers_struct = {
    new SparseOutputPointers(
      out_indptr,
      out_indices,
      out_data),
    &delete_SparseOutputPointers
  };

  // Copy the C buffers into freshly allocated R vectors
  SEXP out_indptr_R = safe_R_int(out_len[1], &cont_token);
  SET_VECTOR_ELT(out, 0, out_indptr_R);
  SEXP out_indices_R = safe_R_int(out_len[0], &cont_token);
  SET_VECTOR_ELT(out, 1, out_indices_R);
  SEXP out_data_R = safe_R_real(out_len[0], &cont_token);
  SET_VECTOR_ELT(out, 2, out_data_R);
  std::memcpy(INTEGER(out_indptr_R), out_indptr, out_len[1]*sizeof(int));
  std::memcpy(INTEGER(out_indices_R), out_indices, out_len[0]*sizeof(int));
  std::memcpy(REAL(out_data_R), out_data, out_len[0]*sizeof(double));

  // 3 = cont_token + out + Rf_asChar(parameter); no allocations happen
  // between here and the return, so returning 'out' unprotected is safe
  UNPROTECT(3);
  return out;
  R_API_END();
}

SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type,
Expand Down Expand Up @@ -975,6 +1061,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
{"LGBM_BoosterPredictSparseOutput_R", (DL_FUNC) &LGBM_BoosterPredictSparseOutput_R, 10},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
Expand Down
29 changes: 29 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,35 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R(
SEXP out_result
);

/*!
* \brief make sparse feature-contribution (predcontrib) predictions for
*        data supplied in CSR or CSC format
* \param handle Booster handle
* \param indptr array with the index pointer of the data in CSR or CSC format
* \param indices array with the non-zero indices of the data in CSR or CSC format (zero-based)
* \param data array with the non-zero values of the data in CSR or CSC format
* \param is_csr whether the input data is in CSR format or not (pass FALSE for CSC)
* \param nrows number of rows in the data
* \param ncols number of columns in the data
* \param start_iteration Start index of the iteration to predict
* \param num_iteration number of iteration for prediction, <= 0 means no limit
* \param parameter additional parameters
* \return An R list with entries "indptr", "indices", "data", constituting the
*         feature contributions in sparse format, in the same storage order as
*         the input data.
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictSparseOutput_R(
  SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP is_csr,
  SEXP nrows,
  SEXP ncols,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter
);

/*!
* \brief save model into file
* \param handle Booster handle
Expand Down
82 changes: 82 additions & 0 deletions R-package/tests/testthat/test_Predictor.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
library(Matrix)

# Verbosity passed to all lightgbm calls in this test file.
# Controlled via the LIGHTGBM_TEST_VERBOSITY environment variable;
# defaults to -1 (silent) when unset.
VERBOSITY <- as.integer(
  Sys.getenv("LIGHTGBM_TEST_VERBOSITY", unset = "-1")
)
Expand Down Expand Up @@ -116,6 +118,84 @@ test_that("start_iteration works correctly", {
expect_equal(pred_leaf1, pred_leaf2)
})

test_that("Feature contributions from sparse inputs produce sparse outputs", {
data(mtcars)
X <- as.matrix(mtcars[, -1L])
y <- as.numeric(mtcars[, 1L])
dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
bst <- lgb.train(
data = dtrain
, obj = "regression"
, nrounds = 5L
, verbose = VERBOSITY
, params = list(min_data_in_leaf = 5L)
)

pred_dense <- predict(bst, X, predcontrib = TRUE)

Xcsc <- as(X, "CsparseMatrix")
pred_csc <- predict(bst, Xcsc, predcontrib = TRUE)
expect_s4_class(pred_csc, "dgCMatrix")
expect_equal(unname(pred_dense), unname(as.matrix(pred_csc)))

Xcsr <- as(X, "RsparseMatrix")
pred_csr <- predict(bst, Xcsr, predcontrib = TRUE)
expect_s4_class(pred_csr, "dgRMatrix")
expect_equal(as(pred_csr, "CsparseMatrix"), pred_csc)

Xspv <- as(X[1L, , drop = FALSE], "sparseVector")
pred_spv <- predict(bst, Xspv, predcontrib = TRUE)
expect_s4_class(pred_spv, "dsparseVector")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beyond just testing the type of the returned objects, can you also please add assertions that the predicted values are the same for all of these cases, and that they're the same as those predicted for a regular R matrix?

Those .Call() calls involve passing a lot of positional arguments with similar values, so such assertions would give us greater confidence that this is working correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought missing data was handled the same way as xgboost, which means predictions for sparse outputs should be different from those of dense inputs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the sparse and dense data structures here are just different representations in memory of the exact same matrices, and that specialized methods for them in LightGBM are just intended to allow that sparse data to stay sparse throughout training + scoring.

And I believe that's not directly related to the handling of missing data (which is described in more detail in the discussion at #2921 (comment) and at https://lightgbm.readthedocs.io/en/latest/Advanced-Topics.html?highlight=missing#missing-value-handle).

Consider the following example:

library(lightgbm)
library(Matrix)

set.seed(708L)

data("EuStockMarkets")

stockDF <- as.data.frame(EuStockMarkets)
feature_names <- c("SMI", "CAC", "FTSE")
target_name <- "DAX"

# randomly set a portion of each feature to NA or 0
for (col_name in feature_names) {
    stockDF[
        sample(
            x = seq_len(nrow(stockDF))
            , size = as.integer(0.01 * nrow(stockDF))
            , replace = FALSE
        )
        , col_name
    ] <- NA_real_
    stockDF[
        sample(
            x = seq_len(nrow(stockDF))
            , size = as.integer(0.01 * nrow(stockDF))
            , replace = FALSE
        )
        , col_name
    ] <- 0.0
}

X_mat <- data.matrix(stockDF[, feature_names])
y <- stockDF[[target_name]]
X_dgCMatrix <- as(X_mat, "dgCMatrix")

bst_mat <- lightgbm::lightgbm(
    data = X_mat
    , label = y
    , objective = "regression"
    , nrounds = 10L
)

bst_dgCMatrix <- lightgbm::lightgbm(
    data = X_dgCMatrix
    , label = y
    , objective = "regression"
    , nrounds = 10L
)


# predicted values don't depend on input type from training time or the type of newdata
preds_mat_mat <- predict(bst_mat, X_mat)
preds_dgCMatrix_mat <- predict(bst_mat, X_dgCMatrix)
preds_mat_dgCMatrix <- predict(bst_dgCMatrix, X_mat)
preds_dgCMatrix_dgCMatrix <- predict(bst_dgCMatrix, X_dgCMatrix)

stopifnot(
    all(
        all(preds_mat_mat == preds_dgCMatrix_mat)
        , all(preds_dgCMatrix_mat == preds_mat_dgCMatrix)
        , all(preds_mat_dgCMatrix == preds_dgCMatrix_dgCMatrix)
    )
)

If you find a case where this is not true and LightGBM is creating different predictions for sparse and dense data, I'd consider that a bug worth addressing.

@shiyu1994 @guolinke @StrikerRUS please correct me if I've misspoken.

Copy link
Contributor Author

@david-cortes david-cortes Apr 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to know, would be useful to have that in the docs, since xgboost works differently (treats non-present sparse entries as missing instead of as zeros) and one might assume both libraries would work the same way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh interesting, I did not know that. https://xgboost.readthedocs.io/en/stable/faq.html#why-do-i-see-different-results-with-sparse-and-dense-data

“Sparse” elements are treated as if they were “missing” by the tree booster, and as zeros by the linear booster. For tree models, it is important to use consistent data formats during training and scoring.


would be useful to have that in the docs

LightGBM's documentation does already describe this behavior directly. Please see https://lightgbm.readthedocs.io/en/latest/Advanced-Topics.html#missing-value-handle

  • LightGBM uses NA (NaN) to represent missing values by default. Change it to use zero by setting zero_as_missing=true
  • When zero_as_missing=false (default), the unrecorded values in sparse matrices (and LightSVM) are treated as zeros.

expect_equal(Matrix::t(as(pred_spv, "CsparseMatrix")), unname(pred_csc[1L, , drop = FALSE]))
})

test_that("Sparse feature contribution predictions do not take inputs with wrong number of columns", {
  data(mtcars)
  features <- as.matrix(mtcars[, -1L])
  target <- as.numeric(mtcars[, 1L])
  dtrain <- lgb.Dataset(features, label = target, params = list(max_bins = 5L))
  model <- lgb.train(
    data = dtrain
    , obj = "regression"
    , nrounds = 5L
    , verbose = VERBOSITY
    , params = list(min_data_in_leaf = 5L)
  )

  # 20 columns (duplicated feature set) where the model expects 10
  too_wide_csc <- as(features[, c(1L:10L, 1L:10L)], "CsparseMatrix")
  expect_error(
    predict(model, too_wide_csc, predcontrib = TRUE)
    , regexp = "input data has 20 columns"
  )

  # same data in CSR form should be rejected too
  too_wide_csr <- as(too_wide_csc, "RsparseMatrix")
  expect_error(
    predict(model, too_wide_csr, predcontrib = TRUE)
    , regexp = "input data has 20 columns"
  )

  # fewer columns than the model was trained on
  too_narrow_csc <- as(too_wide_csr, "CsparseMatrix")[, 1L:3L]
  expect_error(
    predict(model, too_narrow_csc, predcontrib = TRUE)
    , regexp = "input data has 3 columns"
  )
})

test_that("Feature contribution predictions do not take non-general CSR or CSC inputs", {
  set.seed(123L)
  # draw the label first, then the matrix, to keep the RNG stream stable
  label <- runif(25L)
  sym_dense <- as(
    crossprod(matrix(runif(625L), nrow = 25L, ncol = 25L))
    , "symmetricMatrix"
  )
  # symmetric-storage sparse classes (e.g. dsCMatrix / dsRMatrix) are not
  # "general" sparse matrices, so sparse predcontrib should reject them
  sym_csc <- as(sym_dense, "sparseMatrix")
  sym_csr <- as(sym_csc, "RsparseMatrix")

  dtrain <- lgb.Dataset(as.matrix(sym_dense), label = label, params = list(max_bins = 5L))
  model <- lgb.train(
    data = dtrain
    , obj = "regression"
    , nrounds = 5L
    , verbose = VERBOSITY
    , params = list(min_data_in_leaf = 5L)
  )

  expect_error(predict(model, sym_csc, predcontrib = TRUE))
  expect_error(predict(model, sym_csr, predcontrib = TRUE))
})

test_that("predict() params should override keyword argument for raw-score predictions", {
data(agaricus.train, package = "lightgbm")
X <- agaricus.train$data
Expand Down Expand Up @@ -321,6 +401,8 @@ test_that("predict() params should override keyword argument for feature contrib
.expect_has_row_names(pred, Xcsc)
pred <- predict(bst, Xcsc, predcontrib = TRUE)
.expect_has_row_names(pred, Xcsc)
pred <- predict(bst, as(Xcsc, "RsparseMatrix"), predcontrib = TRUE)
.expect_has_row_names(pred, Xcsc)

# sparse matrix without row names
Xcopy <- Xcsc
Expand Down