diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index fbc84bc4ed46..b1662e413b7b 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -279,3 +279,66 @@ test_that("Creating a Booster from a Dataset with an existing predictor should w expect_identical(bst_from_ds$eval_train(), list()) expect_equal(bst_from_ds$current_iter(), nrounds) }) + +test_that("Booster$update() passing a train_set works as expected", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + nrounds <- 2L + + # train with 2 rounds and then update + bst <- lightgbm( + data = as.matrix(agaricus.train$data) + , label = agaricus.train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = nrounds + , objective = "binary" + ) + expect_true(lgb.is.Booster(bst)) + expect_equal(bst$current_iter(), nrounds) + bst$update( + train_set = Dataset$new( + data = agaricus.train$data + , label = agaricus.train$label + ) + ) + expect_true(lgb.is.Booster(bst)) + expect_equal(bst$current_iter(), nrounds + 1L) + + # train with 3 rounds directlry + bst2 <- lightgbm( + data = as.matrix(agaricus.train$data) + , label = agaricus.train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = nrounds + 1L + , objective = "binary" + ) + expect_true(lgb.is.Booster(bst2)) + expect_equal(bst2$current_iter(), nrounds + 1L) + + # model with 2 rounds + 1 update should be identical to 3 rounds + expect_equal(bst2$eval_train()[[1L]][["value"]], 0.04806585) + expect_equal(bst$eval_train()[[1L]][["value"]], bst2$eval_train()[[1L]][["value"]]) +}) + +test_that("Booster$update() throws an informative error if you provide a non-Dataset to update()", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + nrounds <- 2L + + # train with 2 rounds and then update + bst <- lightgbm( + data = as.matrix(agaricus.train$data) + , label = agaricus.train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = nrounds + , objective = "binary" + ) + expect_error({ + bst$update( + train_set = data.frame(x = rnorm(10L)) + ) + }, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE) +})