Skip to content

Commit

Permalink
[R-package] construct dataset earlier in lgb.train and lgb.cv (fixes #…
Browse files Browse the repository at this point in the history
…3583) (#3598)

* construct dataset earlier in lgb.train and lgb.cv

* Update R-package/tests/testthat/test_dataset.R

Co-authored-by: James Lamb <[email protected]>

* Update R-package/R/lgb.cv.R

Co-authored-by: James Lamb <[email protected]>

* Update R-package/R/lgb.train.R

Co-authored-by: James Lamb <[email protected]>

* Update R-package/tests/testthat/test_dataset.R

Co-authored-by: James Lamb <[email protected]>

* fixing lint issues

* styling updates

* fix failing test

Co-authored-by: James Lamb <[email protected]>
  • Loading branch information
tonyk7440 and jameslamb authored Dec 1, 2020
1 parent c02917e commit 9597326
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
7 changes: 4 additions & 3 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ lgb.cv <- function(params = list()
}
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L

# Construct datasets, if needed
data$update_params(params = params)
data$construct()

# Check interaction constraints
cnames <- NULL
if (!is.null(colnames)) {
Expand Down Expand Up @@ -194,9 +198,6 @@ lgb.cv <- function(params = list()
data$set_categorical_feature(categorical_feature)
}

# Construct datasets, if needed
data$construct()

# Check for folds
if (!is.null(folds)) {

Expand Down
6 changes: 4 additions & 2 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ lgb.train <- function(params = list(),
}
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L

# Construct datasets, if needed
data$update_params(params = params)
data$construct()

# Check interaction constraints
cnames <- NULL
if (!is.null(colnames)) {
Expand All @@ -167,8 +171,6 @@ lgb.train <- function(params = list(),
data$set_categorical_feature(categorical_feature)
}

# Construct datasets, if needed
data$construct()
valid_contain_train <- FALSE
train_data_name <- "train"
reduced_valid_sets <- list()
Expand Down
60 changes: 60 additions & 0 deletions R-package/tests/testthat/test_dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,63 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame
expect_identical(new_params[[param_name]], updated_params[[param_name]])
}
})

test_that("lgb.Dataset: should be able to run lgb.train() immediately after using lgb.Dataset() on a file", {
dtest <- lgb.Dataset(
data = test_data
, label = test_label
)
tmp_file <- tempfile(pattern = "lgb.Dataset_")
lgb.Dataset.save(
dataset = dtest
, fname = tmp_file
)

# read from a local file
dtest_read_in <- lgb.Dataset(data = tmp_file)

param <- list(
objective = "binary"
, metric = "binary_logloss"
, num_leaves = 5L
, learning_rate = 1.0
)

# should be able to train right away
bst <- lgb.train(
params = param
, data = dtest_read_in
)

expect_true(lgb.is.Booster(x = bst))
})

test_that("lgb.Dataset: should be able to run lgb.cv() immediately after using lgb.Dataset() on a file", {
dtest <- lgb.Dataset(
data = test_data
, label = test_label
)
tmp_file <- tempfile(pattern = "lgb.Dataset_")
lgb.Dataset.save(
dataset = dtest
, fname = tmp_file
)

# read from a local file
dtest_read_in <- lgb.Dataset(data = tmp_file)

param <- list(
objective = "binary"
, metric = "binary_logloss"
, num_leaves = 5L
, learning_rate = 1.0
)

# should be able to train right away
bst <- lgb.cv(
params = param
, data = dtest_read_in
)

expect_is(bst, "lgb.CVBooster")
})

0 comments on commit 9597326

Please sign in to comment.