Skip to content

Commit

Permalink
Rename models to tree and shorter num_total_models -> num_trees
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Jun 22, 2024
1 parent 18ec224 commit b53ed9e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
26 changes: 13 additions & 13 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -307,43 +307,43 @@ Booster <- R6::R6Class(

},

# Number of models (~trees) per iteration
num_models_per_iter = function() {
# Number of trees per iteration
num_trees_per_iter = function() {

self$restore_handle()

models_per_iter <- 1L
trees_per_iter <- 1L
.Call(
LGBM_BoosterNumModelPerIteration_R
, private$handle
, models_per_iter
, trees_per_iter
)
return(models_per_iter)
return(trees_per_iter)

},

# Total number of models (~trees)
num_total_models = function() {
# Total number of trees
num_trees = function() {

self$restore_handle()

total_models <- 0L
ntrees <- 0L
.Call(
LGBM_BoosterNumberOfTotalModel_R
, private$handle
, total_models
, ntrees
)
return(total_models)
return(ntrees)

},

# Number of iterations (= rounds)
num_iter = function() {

total_models <- self$num_total_models()
models_per_iter <- self$num_models_per_iter()
ntrees <- self$num_trees()
trees_per_iter <- self$num_trees_per_iter()

return(total_models / models_per_iter)
return(ntrees / trees_per_iter)

},

Expand Down
8 changes: 4 additions & 4 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,9 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
);

/*!
* \brief Get number of models (trees) per iteration
* \brief Get number of trees per iteration
* \param handle Booster handle
* \param out Number of models (trees) per iteration
* \param out Number of trees per iteration
* \return R NULL value
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterNumModelPerIteration_R(
Expand All @@ -396,9 +396,9 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterNumModelPerIteration_R(
);

/*!
* \brief Get total number of models (trees)
* \brief Get total number of trees
* \param handle Booster handle
* \param out Total number of models (trees) of Booster
* \param out Total number of trees of Booster
* \return R NULL value
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterNumberOfTotalModel_R(
Expand Down
24 changes: 12 additions & 12 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ test_that("Booster$update() throws an informative error if you provide a non-Dat
}, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
})

test_that("Booster$num_models_per_iter() works as expected", {
test_that("Booster$num_trees_per_iter() works as expected", {
set.seed(708L)

X <- data.matrix(iris[2L:4L])
Expand Down Expand Up @@ -667,12 +667,12 @@ test_that("Booster$num_models_per_iter() works as expected", {
, nrounds = nrounds
)

expect_equal(fit_reg$num_models_per_iter(), 1L)
expect_equal(fit_binary$num_models_per_iter(), 1L)
expect_equal(fit_class$num_models_per_iter(), num_class)
expect_equal(fit_reg$num_trees_per_iter(), 1L)
expect_equal(fit_binary$num_trees_per_iter(), 1L)
expect_equal(fit_class$num_trees_per_iter(), num_class)
})

test_that("Booster$num_total_models() and $num_iter() works (no early stopping)", {
test_that("Booster$num_trees() and $num_iter() works (no early stopping)", {
set.seed(708L)

X <- data.matrix(iris[2L:4L])
Expand Down Expand Up @@ -715,16 +715,16 @@ test_that("Booster$num_total_models() and $num_iter() works (no early stopping)"
, nrounds = nrounds
)

expect_equal(fit_reg$num_total_models(), nrounds)
expect_equal(fit_binary$num_total_models(), nrounds)
expect_equal(fit_class$num_total_models(), num_class * nrounds)
expect_equal(fit_reg$num_trees(), nrounds)
expect_equal(fit_binary$num_trees(), nrounds)
expect_equal(fit_class$num_trees(), num_class * nrounds)

expect_equal(fit_reg$num_iter(), nrounds)
expect_equal(fit_binary$num_iter(), nrounds)
expect_equal(fit_class$num_iter(), nrounds)
})

test_that("Booster$num_total_models() and $num_iter() work (with early stopping)", {
test_that("Booster$num_trees() and $num_iter() work (with early stopping)", {
set.seed(708L)

X <- data.matrix(iris[2L:4L])
Expand Down Expand Up @@ -782,9 +782,9 @@ test_that("Booster$num_total_models() and $num_iter() work (with early stopping)
expected_trees_binary <- fit_binary$best_iter + early_stopping
expected_trees_class <- (fit_class$best_iter + early_stopping) * num_class

expect_equal(fit_reg$num_total_models(), expected_trees_reg)
expect_equal(fit_binary$num_total_models(), expected_trees_binary)
expect_equal(fit_class$num_total_models(), expected_trees_class)
expect_equal(fit_reg$num_trees(), expected_trees_reg)
expect_equal(fit_binary$num_trees(), expected_trees_binary)
expect_equal(fit_class$num_trees(), expected_trees_class)

expect_equal(fit_reg$num_iter(), expected_trees_reg)
expect_equal(fit_binary$num_iter(), expected_trees_binary)
Expand Down

0 comments on commit b53ed9e

Please sign in to comment.