From 9597326eec6fd7dc4ae95796b9b2d1d6bf748b53 Mon Sep 17 00:00:00 2001 From: Tony Kenny Date: Tue, 1 Dec 2020 02:01:02 +0000 Subject: [PATCH] [R-package] construct dataset earlier in lgb.train and lgb.cv (fixes #3583) (#3598) * construct dataset earlier in lgb.train and lgb.cv * Update R-package/tests/testthat/test_dataset.R Co-authored-by: James Lamb * Update R-package/R/lgb.cv.R Co-authored-by: James Lamb * Update R-package/R/lgb.train.R Co-authored-by: James Lamb * Update R-package/tests/testthat/test_dataset.R Co-authored-by: James Lamb * fixing lint issues * styling updates * fix failing test Co-authored-by: James Lamb --- R-package/R/lgb.cv.R | 7 +-- R-package/R/lgb.train.R | 6 ++- R-package/tests/testthat/test_dataset.R | 60 +++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index c5878c8b6f07..cc36635f7cc2 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -164,6 +164,10 @@ lgb.cv <- function(params = list() } end_iteration <- begin_iteration + params[["num_iterations"]] - 1L + # Construct datasets, if needed + data$update_params(params = params) + data$construct() + # Check interaction constraints cnames <- NULL if (!is.null(colnames)) { @@ -194,9 +198,6 @@ lgb.cv <- function(params = list() data$set_categorical_feature(categorical_feature) } - # Construct datasets, if needed - data$construct() - # Check for folds if (!is.null(folds)) { diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index 71e2c3ad7ef8..7af0d22e2252 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -142,6 +142,10 @@ lgb.train <- function(params = list(), } end_iteration <- begin_iteration + params[["num_iterations"]] - 1L + # Construct datasets, if needed + data$update_params(params = params) + data$construct() + # Check interaction constraints cnames <- NULL if (!is.null(colnames)) { @@ -167,8 +171,6 @@ lgb.train <- function(params = list(), data$set_categorical_feature(categorical_feature) } - # Construct datasets, if needed - data$construct() valid_contain_train <- FALSE train_data_name <- "train" reduced_valid_sets <- list() diff --git a/R-package/tests/testthat/test_dataset.R b/R-package/tests/testthat/test_dataset.R index d0ac9c0627d2..ea87d7138d70 100644 --- a/R-package/tests/testthat/test_dataset.R +++ b/R-package/tests/testthat/test_dataset.R @@ -205,3 +205,63 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame expect_identical(new_params[[param_name]], updated_params[[param_name]]) } }) + +test_that("lgb.Dataset: should be able to run lgb.train() immediately after using lgb.Dataset() on a file", { + dtest <- lgb.Dataset( + data = test_data + , label = test_label + ) + tmp_file <- tempfile(pattern = "lgb.Dataset_") + lgb.Dataset.save( + dataset = dtest + , fname = tmp_file + ) + + # read from a local file + dtest_read_in <- lgb.Dataset(data = tmp_file) + + param <- list( + objective = "binary" + , metric = "binary_logloss" + , num_leaves = 5L + , learning_rate = 1.0 + ) + + # should be able to train right away + bst <- lgb.train( + params = param + , data = dtest_read_in + ) + + expect_true(lgb.is.Booster(x = bst)) +}) + +test_that("lgb.Dataset: should be able to run lgb.cv() immediately after using lgb.Dataset() on a file", { + dtest <- lgb.Dataset( + data = test_data + , label = test_label + ) + tmp_file <- tempfile(pattern = "lgb.Dataset_") + lgb.Dataset.save( + dataset = dtest + , fname = tmp_file + ) + + # read from a local file + dtest_read_in <- lgb.Dataset(data = tmp_file) + + param <- list( + objective = "binary" + , metric = "binary_logloss" + , num_leaves = 5L + , learning_rate = 1.0 + ) + + # should be able to train right away + bst <- lgb.cv( + params = param + , data = dtest_read_in + ) + + expect_is(bst, "lgb.CVBooster") +})