Skip to content

Commit

Permalink
[R-package] fix segfaults caused by missing Booster and Dataset handl…
Browse files Browse the repository at this point in the history
…es (fixes #4208) (#4586)

* [R-package] fix segfaults caused by missing Booster and Dataset handles (fixes #4208)

* fix test errors

* fixes for cpplint

* Update R-package/tests/testthat/test_dataset.R

Co-authored-by: Nikita Titov <[email protected]>

* fix tests

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* move asserts inside try-catch

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
jameslamb and StrikerRUS authored Sep 25, 2021
1 parent d462972 commit f8010d6
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 1 deletion.
11 changes: 10 additions & 1 deletion R-package/R/lgb.Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,18 @@ Dataset <- R6::R6Class(
ref_handle <- private$reference$.__enclos_env__$private$get_handle()
}

# Not subsetting
# not subsetting, constructing from raw data
if (is.null(private$used_indices)) {

if (is.null(private$raw_data)) {
stop(paste0(
"Attempting to create a Dataset without any raw data. "
, "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
, "To avoid this error in the future, use lgb.Dataset.save() or "
, "Dataset$save_binary() to save lightgbm Datasets."
))
}

# Are we using a data file?
if (is.character(private$raw_data)) {

Expand Down
53 changes: 53 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,24 @@ void _DatasetFinalizer(SEXP handle) {
LGBM_DatasetFree_R(handle);
}

// Guard for Booster handles coming from R.
//
// Returns silently when `handle` is not R's NULL and its external-pointer
// address is non-null. Otherwise raises an R error (via Rf_error()) telling
// the user why the Booster is gone and how to persist Boosters safely.
void _AssertBoosterHandleNotNull(SEXP handle) {
  // R_ExternalPtrAddr() is only consulted when the handle is not R's NULL,
  // mirroring the short-circuit order callers rely on.
  const bool booster_is_alive =
    !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (booster_is_alive) {
    return;
  }
  Rf_error(
    "Attempting to use a Booster which no longer exists. "
    "This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). "
    "To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.");
}

// Guard for Dataset handles coming from R.
//
// Returns silently when `handle` is not R's NULL and its external-pointer
// address is non-null. Otherwise raises an R error (via Rf_error()) telling
// the user why the Dataset is gone and how to persist Datasets safely.
void _AssertDatasetHandleNotNull(SEXP handle) {
  // R_ExternalPtrAddr() is only consulted when the handle is not R's NULL,
  // mirroring the short-circuit order callers rely on.
  const bool dataset_is_alive =
    !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (dataset_is_alive) {
    return;
  }
  Rf_error(
    "Attempting to use a Dataset which no longer exists. "
    "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
    "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.");
}

SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
SEXP parameters,
SEXP reference) {
Expand Down Expand Up @@ -172,6 +190,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP len_used_row_indices,
SEXP parameters) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
std::vector<int32_t> idxvec(len);
Expand All @@ -195,6 +214,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP feature_names) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t');
std::vector<const char*> vec_sptr;
int len = static_cast<int>(vec_names.size());
Expand All @@ -211,6 +231,7 @@ SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
SEXP feature_names;
int len = 0;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
Expand Down Expand Up @@ -258,6 +279,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
SEXP filename) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle),
filename_ptr));
Expand All @@ -281,6 +303,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
SEXP field_data,
SEXP num_element) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int len = Rf_asInteger(num_element);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
if (!strcmp("group", name) || !strcmp("query", name)) {
Expand Down Expand Up @@ -309,6 +332,7 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
SEXP field_name,
SEXP field_data) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0;
int out_type = 0;
Expand Down Expand Up @@ -343,6 +367,7 @@ SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
SEXP field_name,
SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0;
int out_type = 0;
Expand Down Expand Up @@ -370,6 +395,7 @@ SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,

SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int nrow;
CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
INTEGER(out)[0] = nrow;
Expand All @@ -380,6 +406,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int nfeature;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature));
INTEGER(out)[0] = nfeature;
Expand All @@ -406,6 +433,7 @@ SEXP LGBM_BoosterFree_R(SEXP handle) {
SEXP LGBM_BoosterCreate_R(SEXP train_data,
SEXP parameters) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(train_data);
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
BoosterHandle handle = nullptr;
Expand Down Expand Up @@ -448,6 +476,8 @@ SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
SEXP LGBM_BoosterMerge_R(SEXP handle,
SEXP other_handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertBoosterHandleNotNull(other_handle);
CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle)));
return R_NilValue;
R_API_END();
Expand All @@ -456,6 +486,8 @@ SEXP LGBM_BoosterMerge_R(SEXP handle,
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
SEXP valid_data) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertDatasetHandleNotNull(valid_data);
CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data)));
return R_NilValue;
R_API_END();
Expand All @@ -464,6 +496,8 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle,
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
SEXP train_data) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertDatasetHandleNotNull(train_data);
CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data)));
return R_NilValue;
R_API_END();
Expand All @@ -472,6 +506,7 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP parameters) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr));
UNPROTECT(1);
Expand All @@ -482,6 +517,7 @@ SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int num_class;
CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class));
INTEGER(out)[0] = num_class;
Expand All @@ -491,6 +527,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,

SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int is_finished = 0;
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished));
return R_NilValue;
Expand All @@ -502,6 +539,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
SEXP hess,
SEXP len) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int is_finished = 0;
int int_len = Rf_asInteger(len);
std::vector<float> tgrad(int_len), thess(int_len);
Expand All @@ -517,6 +555,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,

SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle)));
return R_NilValue;
R_API_END();
Expand All @@ -525,6 +564,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out_iteration;
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration));
INTEGER(out)[0] = out_iteration;
Expand All @@ -535,6 +575,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
return R_NilValue;
Expand All @@ -544,6 +585,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
return R_NilValue;
Expand All @@ -553,6 +595,7 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP eval_names;
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
Expand Down Expand Up @@ -602,6 +645,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle,
SEXP data_idx,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
double* ptr_ret = REAL(out_result);
Expand All @@ -616,6 +660,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
SEXP data_idx,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int64_t len;
CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len));
INTEGER(out)[0] = static_cast<int>(len);
Expand All @@ -627,6 +672,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle,
SEXP data_idx,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
Expand Down Expand Up @@ -659,6 +705,7 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
SEXP parameter,
SEXP result_filename) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename)));
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename)));
Expand All @@ -680,6 +727,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
SEXP num_iteration,
SEXP out_len) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
Expand All @@ -704,6 +752,7 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
SEXP parameter,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = INTEGER(indptr);
const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices));
Expand Down Expand Up @@ -735,6 +784,7 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
SEXP parameter,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
Expand All @@ -755,6 +805,7 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP feature_importance_type,
SEXP filename) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
UNPROTECT(1);
Expand All @@ -767,6 +818,7 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
Expand All @@ -791,6 +843,7 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
Expand Down
48 changes: 48 additions & 0 deletions R-package/tests/testthat/test_dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,51 @@ test_that("lgb.Dataset: should be able to create a Dataset from a text file with
expect_identical(dtrain$get_params(), list(header = FALSE))
expect_identical(dtrain$dim(), c(100L, 2L))
})

test_that("Dataset: method calls on a Dataset with a null handle should raise an informative error and not segfault", {
    data(agaricus.train, package = "lightgbm")
    mushrooms <- agaricus.train
    first_rows <- seq_len(100L)

    # Build and construct a training Dataset and a validation Dataset derived from it.
    dtrain <- lgb.Dataset(mushrooms$data, label = mushrooms$label)
    dtrain$construct()
    dvalid <- dtrain$create_valid(
        data = mushrooms$data[first_rows, ]
        , label = mushrooms$label[first_rows]
    )
    dvalid$construct()

    # Round-trip the training Dataset through saveRDS() / readRDS().
    train_rds <- tempfile(fileext = ".rds")
    saveRDS(dtrain, train_rds)
    rm(dtrain)
    dtrain <- readRDS(train_rds)

    # Each method call on the restored Dataset must fail with a descriptive R error.
    expect_error(
        dtrain$construct()
        , regexp = "Attempting to create a Dataset without any raw data"
    )
    expect_error(
        dtrain$dim()
        , regexp = "cannot get dimensions before dataset has been constructed"
    )
    expect_error(
        dtrain$get_colnames()
        , regexp = "cannot get column names before dataset has been constructed"
    )
    expect_error(
        dtrain$save_binary(fname = tempfile(fileext = ".bin"))
        , regexp = "Attempting to create a Dataset without any raw data"
    )
    expect_error(
        dtrain$set_categorical_feature(categorical_feature = 1L)
        , regexp = "cannot set categorical feature after freeing raw data"
    )
    expect_error(
        dtrain$set_reference(reference = dvalid)
        , regexp = "cannot set reference after freeing raw data"
    )

    # Now round-trip the validation Dataset instead, and use it as the reference
    # for a freshly constructed training Dataset that keeps its raw data.
    valid_rds <- tempfile(fileext = ".rds")
    saveRDS(dvalid, valid_rds)
    rm(dvalid)
    dvalid <- readRDS(valid_rds)
    dtrain <- lgb.Dataset(
        mushrooms$data
        , label = mushrooms$label
        , free_raw_data = FALSE
    )
    dtrain$construct()
    expect_error(
        dtrain$set_reference(reference = dvalid)
        , regexp = "cannot get column names before dataset has been constructed"
    )
})
Loading

0 comments on commit f8010d6

Please sign in to comment.