Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] construct dataset earlier in lgb.train and lgb.cv (fixes #3583) #3598

Merged
merged 9 commits into from
Dec 1, 2020
7 changes: 4 additions & 3 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ lgb.cv <- function(params = list()
}
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L

# Construct datasets, if needed
tonyk7440 marked this conversation as resolved.
Show resolved Hide resolved
data$update_params(params = params)
data$construct()

# Check interaction constraints
cnames <- NULL
if (!is.null(colnames)) {
Expand Down Expand Up @@ -194,9 +198,6 @@ lgb.cv <- function(params = list()
data$set_categorical_feature(categorical_feature)
}

# Construct datasets, if needed
data$construct()

# Check for folds
if (!is.null(folds)) {

Expand Down
6 changes: 4 additions & 2 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ lgb.train <- function(params = list(),
}
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L

# Construct datasets, if needed
tonyk7440 marked this conversation as resolved.
Show resolved Hide resolved
data$update_params(params = params)
data$construct()

# Check interaction constraints
cnames <- NULL
if (!is.null(colnames)) {
Expand All @@ -167,8 +171,6 @@ lgb.train <- function(params = list(),
data$set_categorical_feature(categorical_feature)
}

# Construct datasets, if needed
data$construct()
valid_contain_train <- FALSE
train_data_name <- "train"
reduced_valid_sets <- list()
Expand Down
60 changes: 60 additions & 0 deletions R-package/tests/testthat/test_dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,63 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame
expect_identical(new_params[[param_name]], updated_params[[param_name]])
}
})

test_that("lgb.Dataset: should be able to run lgb.train() immediately after using lgb.Dataset() on a file", {
dtest <- lgb.Dataset(
data = test_data
, label = test_label
)
tmp_file <- tempfile(pattern = "lgb.Dataset_")
lgb.Dataset.save(
dataset = dtest
, fname = tmp_file
)

# read from a local file
dtest_read_in <- lgb.Dataset(data = tmp_file)

param <- list(
objective = "binary"
, metric = "binary_logloss"
, num_leaves = 5L
, learning_rate = 1.0
)

# should be able to train right away
bst <- lgb.train(
params = param
, data = dtest_read_in
)

expect_true(lgb.is.Booster(x = bst))
})

test_that("lgb.Dataset: should be able to run lgb.cv() immediately after using lgb.Dataset() on a file", {
dtest <- lgb.Dataset(
data = test_data
, label = test_label
)
tmp_file <- tempfile(pattern = "lgb.Dataset_")
lgb.Dataset.save(
dataset = dtest
, fname = tmp_file
)

# read from a local file
dtest_read_in <- lgb.Dataset(data = tmp_file)

param <- list(
objective = "binary"
, metric = "binary_logloss"
, num_leaves = 5L
, learning_rate = 1.0
)

# should be able to train right away
bst <- lgb.cv(
params = param
, data = dtest_read_in
)

expect_is(bst, "lgb.CVBooster")
})