diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R
index bce9e785bf4e..a13516ff6569 100644
--- a/R-package/R/lgb.Booster.R
+++ b/R-package/R/lgb.Booster.R
@@ -307,6 +307,46 @@ Booster <- R6::R6Class(
 
     },
 
+    # Number of trees per iteration
+    num_trees_per_iter = function() {
+
+      self$restore_handle()
+
+      trees_per_iter <- 1L
+      .Call(
+        LGBM_BoosterNumModelPerIteration_R
+        , private$handle
+        , trees_per_iter
+      )
+      return(trees_per_iter)
+
+    },
+
+    # Total number of trees
+    num_trees = function() {
+
+      self$restore_handle()
+
+      ntrees <- 0L
+      .Call(
+        LGBM_BoosterNumberOfTotalModel_R
+        , private$handle
+        , ntrees
+      )
+      return(ntrees)
+
+    },
+
+    # Number of iterations (= rounds)
+    num_iter = function() {
+
+      ntrees <- self$num_trees()
+      trees_per_iter <- self$num_trees_per_iter()
+
+      return(ntrees / trees_per_iter)
+
+    },
+
     # Get upper bound
     upper_bound = function() {
 
diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp
index e8383c5c366e..045f0a9da04b 100644
--- a/R-package/src/lightgbm_R.cpp
+++ b/R-package/src/lightgbm_R.cpp
@@ -763,8 +763,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
   R_API_END();
 }
 
-SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
-    SEXP out) {
+SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP out) {
   R_API_BEGIN();
   _AssertBoosterHandleNotNull(handle);
   int out_iteration;
@@ -774,6 +773,26 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
   R_API_END();
 }
 
+SEXP LGBM_BoosterNumModelPerIteration_R(SEXP handle, SEXP out) {
+  R_API_BEGIN();
+  _AssertBoosterHandleNotNull(handle);
+  int models_per_iter;
+  CHECK_CALL(LGBM_BoosterNumModelPerIteration(R_ExternalPtrAddr(handle), &models_per_iter));
+  INTEGER(out)[0] = models_per_iter;
+  return R_NilValue;
+  R_API_END();
+}
+
+SEXP LGBM_BoosterNumberOfTotalModel_R(SEXP handle, SEXP out) {
+  R_API_BEGIN();
+  _AssertBoosterHandleNotNull(handle);
+  int total_models;
+  CHECK_CALL(LGBM_BoosterNumberOfTotalModel(R_ExternalPtrAddr(handle), &total_models));
+  INTEGER(out)[0] = total_models;
+  return R_NilValue;
+  R_API_END();
+}
+
 SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
     SEXP out_result) {
   R_API_BEGIN();
@@ -1431,6 +1450,8 @@ static const R_CallMethodDef CallEntries[] = {
   {"LGBM_BoosterUpdateOneIterCustom_R"  , (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R  , 4},
   {"LGBM_BoosterRollbackOneIter_R"      , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R      , 1},
   {"LGBM_BoosterGetCurrentIteration_R"  , (DL_FUNC) &LGBM_BoosterGetCurrentIteration_R  , 2},
+  {"LGBM_BoosterNumModelPerIteration_R" , (DL_FUNC) &LGBM_BoosterNumModelPerIteration_R , 2},
+  {"LGBM_BoosterNumberOfTotalModel_R"   , (DL_FUNC) &LGBM_BoosterNumberOfTotalModel_R   , 2},
  {"LGBM_BoosterGetUpperBoundValue_R"   , (DL_FUNC) &LGBM_BoosterGetUpperBoundValue_R   , 2},
  {"LGBM_BoosterGetLowerBoundValue_R"   , (DL_FUNC) &LGBM_BoosterGetLowerBoundValue_R   , 2},
  {"LGBM_BoosterGetEvalNames_R"         , (DL_FUNC) &LGBM_BoosterGetEvalNames_R         , 1},
diff --git a/R-package/src/lightgbm_R.h b/R-package/src/lightgbm_R.h
index 574c9733acd9..9d100d095e9e 100644
--- a/R-package/src/lightgbm_R.h
+++ b/R-package/src/lightgbm_R.h
@@ -384,6 +384,28 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
   SEXP out
 );
 
+/*!
+ * \brief Get number of trees per iteration
+ * \param handle Booster handle
+ * \param out Number of trees per iteration
+ * \return R NULL value
+ */
+LIGHTGBM_C_EXPORT SEXP LGBM_BoosterNumModelPerIteration_R(
+  SEXP handle,
+  SEXP out
+);
+
+/*!
+ * \brief Get total number of trees
+ * \param handle Booster handle
+ * \param out Total number of trees of Booster
+ * \return R NULL value
+ */
+LIGHTGBM_C_EXPORT SEXP LGBM_BoosterNumberOfTotalModel_R(
+  SEXP handle,
+  SEXP out
+);
+
 /*!
  * \brief Get model upper bound value.
  * \param handle Handle of Booster
diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R
index e81dc89673e0..80cf7775813d 100644
--- a/R-package/tests/testthat/test_lgb.Booster.R
+++ b/R-package/tests/testthat/test_lgb.Booster.R
@@ -623,6 +623,174 @@ test_that("Booster$update() throws an informative error if you provide a non-Dat
   }, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
 })
 
+test_that("Booster$num_trees_per_iter() works as expected", {
+  set.seed(708L)
+
+  X <- data.matrix(iris[2L:4L])
+  y_reg <- iris[, 1L]
+  y_binary <- as.integer(y_reg > median(y_reg))
+  y_class <- as.integer(iris[, 5L]) - 1L
+  num_class <- 3L
+
+  nrounds <- 10L
+
+  # Regression and binary probabilistic classification (1 iteration = 1 tree)
+  fit_reg <- lgb.train(
+    params = list(
+      objective = "mse"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+    )
+    , data = lgb.Dataset(X, label = y_reg)
+    , nrounds = nrounds
+  )
+
+  fit_binary <- lgb.train(
+    params = list(
+      objective = "binary"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+    )
+    , data = lgb.Dataset(X, label = y_binary)
+    , nrounds = nrounds
+  )
+
+  # Multiclass probabilistic classification (1 iteration = num_class trees)
+  fit_class <- lgb.train(
+    params = list(
+      objective = "multiclass"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+      , num_class = num_class
+    )
+    , data = lgb.Dataset(X, label = y_class)
+    , nrounds = nrounds
+  )
+
+  expect_equal(fit_reg$num_trees_per_iter(), 1L)
+  expect_equal(fit_binary$num_trees_per_iter(), 1L)
+  expect_equal(fit_class$num_trees_per_iter(), num_class)
+})
+
+test_that("Booster$num_trees() and $num_iter() work (no early stopping)", {
+  set.seed(708L)
+
+  X <- data.matrix(iris[2L:4L])
+  y_reg <- iris[, 1L]
+  y_binary <- as.integer(y_reg > median(y_reg))
+  y_class <- as.integer(iris[, 5L]) - 1L
+  num_class <- 3L
+  nrounds <- 10L
+
+  # Regression and binary probabilistic classification (1 iteration = 1 tree)
+  fit_reg <- lgb.train(
+    params = list(
+      objective = "mse"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+    )
+    , data = lgb.Dataset(X, label = y_reg)
+    , nrounds = nrounds
+  )
+
+  fit_binary <- lgb.train(
+    params = list(
+      objective = "binary"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+    )
+    , data = lgb.Dataset(X, label = y_binary)
+    , nrounds = nrounds
+  )
+
+  # Multiclass probabilistic classification (1 iteration = num_class trees)
+  fit_class <- lgb.train(
+    params = list(
+      objective = "multiclass"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+      , num_class = num_class
+    )
+    , data = lgb.Dataset(X, label = y_class)
+    , nrounds = nrounds
+  )
+
+  expect_equal(fit_reg$num_trees(), nrounds)
+  expect_equal(fit_binary$num_trees(), nrounds)
+  expect_equal(fit_class$num_trees(), num_class * nrounds)
+
+  expect_equal(fit_reg$num_iter(), nrounds)
+  expect_equal(fit_binary$num_iter(), nrounds)
+  expect_equal(fit_class$num_iter(), nrounds)
+})
+
+test_that("Booster$num_trees() and $num_iter() work (with early stopping)", {
+  set.seed(708L)
+
+  X <- data.matrix(iris[2L:4L])
+  y_reg <- iris[, 1L]
+  y_binary <- as.integer(y_reg > median(y_reg))
+  y_class <- as.integer(iris[, 5L]) - 1L
+  train_ix <- c(1L:40L, 51L:90L, 101L:140L)
+  X_train <- X[train_ix, ]
+  X_valid <- X[-train_ix, ]
+
+  num_class <- 3L
+  nrounds <- 1000L
+  early_stopping <- 2L
+
+  # Regression and binary probabilistic classification (1 iteration = 1 tree)
+  fit_reg <- lgb.train(
+    params = list(
+      objective = "mse"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+    )
+    , data = lgb.Dataset(X_train, label = y_reg[train_ix])
+    , valids = list(valid = lgb.Dataset(X_valid, label = y_reg[-train_ix]))
+    , nrounds = nrounds
+    , early_stopping_round = early_stopping
+  )
+
+  fit_binary <- lgb.train(
+    params = list(
+      objective = "binary"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+    )
+    , data = lgb.Dataset(X_train, label = y_binary[train_ix])
+    , valids = list(valid = lgb.Dataset(X_valid, label = y_binary[-train_ix]))
+    , nrounds = nrounds
+    , early_stopping_round = early_stopping
+  )
+
+  # Multiclass probabilistic classification (1 iteration = num_class trees)
+  fit_class <- lgb.train(
+    params = list(
+      objective = "multiclass"
+      , verbose = .LGB_VERBOSITY
+      , num_threads = .LGB_MAX_THREADS
+      , num_class = num_class
+    )
+    , data = lgb.Dataset(X_train, label = y_class[train_ix])
+    , valids = list(valid = lgb.Dataset(X_valid, label = y_class[-train_ix]))
+    , nrounds = nrounds
+    , early_stopping_round = early_stopping
+  )
+
+  expected_trees_reg <- fit_reg$best_iter + early_stopping
+  expected_trees_binary <- fit_binary$best_iter + early_stopping
+  expected_trees_class <- (fit_class$best_iter + early_stopping) * num_class
+
+  expect_equal(fit_reg$num_trees(), expected_trees_reg)
+  expect_equal(fit_binary$num_trees(), expected_trees_binary)
+  expect_equal(fit_class$num_trees(), expected_trees_class)
+
+  expect_equal(fit_reg$num_iter(), expected_trees_reg)
+  expect_equal(fit_binary$num_iter(), expected_trees_binary)
+  expect_equal(fit_class$num_iter(), expected_trees_class / num_class)
+})
+
 test_that("Booster should store parameters and Booster$reset_parameter() should update them", {
   data(agaricus.train, package = "lightgbm")
   dtrain <- lgb.Dataset(
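For reviewers who want to try the new accessors from this branch, here is a minimal usage sketch (not part of the patch; the data, parameter values, and object names are illustrative). It trains a small multiclass model, where each iteration grows num_class trees, and then queries the three new Booster methods:

library(lightgbm)

X <- data.matrix(iris[, 1L:4L])
y <- as.integer(iris[, 5L]) - 1L

# Multiclass objective: every boosting iteration adds num_class trees.
bst <- lgb.train(
  params = list(objective = "multiclass", num_class = 3L, verbose = -1L)
  , data = lgb.Dataset(X, label = y)
  , nrounds = 5L
)

bst$num_trees_per_iter()  # should return 3L (one tree per class per iteration)
bst$num_trees()           # should return 15L (5 iterations * 3 trees)
bst$num_iter()            # should return 5 (total trees / trees per iteration)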