diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 385621729f30..59acd9cd3508 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -11,21 +11,12 @@ Booster <- R6::R6Class( # Finalize will free up the handles finalize = function() { - - # Check the need for freeing handle - if (!lgb.is.null.handle(x = private$handle)) { - - # Freeing up handle - .Call( - LGBM_BoosterFree_R - , private$handle - ) - private$handle <- NULL - - } - + .Call( + LGBM_BoosterFree_R + , private$handle + ) + private$handle <- NULL return(invisible(NULL)) - }, # Initialize will create a starter booster diff --git a/R-package/R/lgb.Dataset.R b/R-package/R/lgb.Dataset.R index 9665fbdd59d0..cf1adea7a76c 100644 --- a/R-package/R/lgb.Dataset.R +++ b/R-package/R/lgb.Dataset.R @@ -8,21 +8,12 @@ Dataset <- R6::R6Class( # Finalize will free up the handles finalize = function() { - - # Check the need for freeing handle - if (!lgb.is.null.handle(x = private$handle)) { - - # Freeing up handle - .Call( - LGBM_DatasetFree_R - , private$handle - ) - private$handle <- NULL - - } - + .Call( + LGBM_DatasetFree_R + , private$handle + ) + private$handle <- NULL return(invisible(NULL)) - }, # Initialize will create a starter dataset diff --git a/R-package/R/lgb.Predictor.R b/R-package/R/lgb.Predictor.R index e819ab4fd837..b1fc5bce1eb8 100644 --- a/R-package/R/lgb.Predictor.R +++ b/R-package/R/lgb.Predictor.R @@ -11,9 +11,8 @@ Predictor <- R6::R6Class( finalize = function() { # Check the need for freeing handle - if (private$need_free_handle && !lgb.is.null.handle(x = private$handle)) { + if (private$need_free_handle) { - # Freeing up handle .Call( LGBM_BoosterFree_R , private$handle diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index 74704d92b1ec..7c5b5387dda9 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -46,6 +46,10 @@ SEXP LGBM_HandleIsNull_R(SEXP handle) { return Rf_ScalarLogical(R_ExternalPtrAddr(handle) == NULL); } +void _DatasetFinalizer(SEXP handle) { + LGBM_DatasetFree_R(handle); +} + SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, SEXP parameters, SEXP reference) { @@ -59,6 +63,7 @@ SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(Rf_asChar(filename)), CHAR(Rf_asChar(parameters)), ref, &handle)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); + R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); UNPROTECT(1); return ret; R_API_END(); @@ -90,6 +95,7 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, nrow, CHAR(Rf_asChar(parameters)), ref, &handle)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); + R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); UNPROTECT(1); return ret; R_API_END(); @@ -113,6 +119,7 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data, CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, CHAR(Rf_asChar(parameters)), ref, &handle)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); + R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); UNPROTECT(1); return ret; R_API_END(); @@ -136,6 +143,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, idxvec.data(), len, CHAR(Rf_asChar(parameters)), &res)); ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue)); + R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); UNPROTECT(1); return ret; R_API_END(); @@ -211,7 +219,7 @@ SEXP LGBM_DatasetSaveBinary_R(SEXP handle, SEXP LGBM_DatasetFree_R(SEXP handle) { R_API_BEGIN(); - if (R_ExternalPtrAddr(handle)) { + if (!Rf_isNull(handle) && R_ExternalPtrAddr(handle)) { CHECK_CALL(LGBM_DatasetFree(R_ExternalPtrAddr(handle))); R_ClearExternalPtr(handle); } @@ -320,9 +328,13 @@ SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, // --- start Booster interfaces +void _BoosterFinalizer(SEXP handle) { + LGBM_BoosterFree_R(handle); +} + SEXP LGBM_BoosterFree_R(SEXP handle) { R_API_BEGIN(); - if (R_ExternalPtrAddr(handle)) { + if (!Rf_isNull(handle) && R_ExternalPtrAddr(handle)) { CHECK_CALL(LGBM_BoosterFree(R_ExternalPtrAddr(handle))); R_ClearExternalPtr(handle); } @@ -336,6 +348,7 @@ SEXP LGBM_BoosterCreate_R(SEXP train_data, BoosterHandle handle = nullptr; CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(Rf_asChar(parameters)), &handle)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); + R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); UNPROTECT(1); return ret; R_API_END(); @@ -348,6 +361,7 @@ SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) { BoosterHandle handle = nullptr; CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(Rf_asChar(filename)), &out_num_iterations, &handle)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); + R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); UNPROTECT(1); return ret; R_API_END(); @@ -360,6 +374,7 @@ SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) { BoosterHandle handle = nullptr; CHECK_CALL(LGBM_BoosterLoadModelFromString(CHAR(Rf_asChar(model_str)), &out_num_iterations, &handle)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); + R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); UNPROTECT(1); return ret; R_API_END(); diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 77719f2367a4..eca50a4f123a 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -1,5 +1,31 @@ context("Predictor") +test_that("Predictor$finalize() should not fail", { + X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L) + y <- iris[["Sepal.Length"]] + dtrain <- lgb.Dataset(X, label = y) + bst <- lgb.train( + data = dtrain + , objective = "regression" + , verbose = -1L + , nrounds = 3L + ) + model_file <- tempfile(fileext = ".model") + bst$save_model(filename = model_file) + predictor <- Predictor$new(modelfile = model_file) + + expect_true(lgb.is.Predictor(predictor)) + + expect_false(lgb.is.null.handle(predictor$.__enclos_env__$private$handle)) + + predictor$finalize() + expect_true(lgb.is.null.handle(predictor$.__enclos_env__$private$handle)) + + # calling finalize() a second time shouldn't cause any issues + predictor$finalize() + expect_true(lgb.is.null.handle(predictor$.__enclos_env__$private$handle)) +}) + test_that("predictions do not fail for integer input", { X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L) y <- iris[["Sepal.Length"]] diff --git a/R-package/tests/testthat/test_dataset.R b/R-package/tests/testthat/test_dataset.R index 93ccbf23288a..a0b0670f8745 100644 --- a/R-package/tests/testthat/test_dataset.R +++ b/R-package/tests/testthat/test_dataset.R @@ -215,6 +215,24 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame } }) +test_that("Dataset$finalize() should not fail on an already-finalized Dataset", { + dtest <- lgb.Dataset( + data = test_data + , label = test_label + ) + expect_true(lgb.is.null.handle(dtest$.__enclos_env__$private$handle)) + + dtest$construct() + expect_false(lgb.is.null.handle(dtest$.__enclos_env__$private$handle)) + + dtest$finalize() + expect_true(lgb.is.null.handle(dtest$.__enclos_env__$private$handle)) + + # calling finalize() a second time shouldn't cause any issues + dtest$finalize() + expect_true(lgb.is.null.handle(dtest$.__enclos_env__$private$handle)) +}) + test_that("lgb.Dataset: should be able to run lgb.train() immediately after using lgb.Dataset() on a file", { dtest <- lgb.Dataset( data = test_data diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 4f39a036590c..735f2fef9b66 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -1,3 +1,27 @@ +context("Booster") + +test_that("Booster$finalize() should not fail", { + X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L) + y <- iris[["Sepal.Length"]] + dtrain <- lgb.Dataset(X, label = y) + bst <- lgb.train( + data = dtrain + , objective = "regression" + , verbose = -1L + , nrounds = 3L + ) + expect_true(lgb.is.Booster(bst)) + + expect_false(lgb.is.null.handle(bst$.__enclos_env__$private$handle)) + + bst$finalize() + expect_true(lgb.is.null.handle(bst$.__enclos_env__$private$handle)) + + # calling finalize() a second time shouldn't cause any issues + bst$finalize() + expect_true(lgb.is.null.handle(bst$.__enclos_env__$private$handle)) +}) + context("lgb.get.eval.result") test_that("lgb.get.eval.result() should throw an informative error if booster is not an lgb.Booster", {