[docs] Clarify the fact that predict() on a file does not support saved Datasets (fixes #4034) #4545

Merged
merged 10 commits into from
Aug 25, 2021
3 changes: 2 additions & 1 deletion R-package/R/lgb.Booster.R
@@ -682,7 +682,8 @@ Booster <- R6::R6Class(
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
+#' a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration int or None, optional (default=None)
#' Start index of the iteration to predict.
#' If None or <= 0, starts from the first iteration.
7 changes: 5 additions & 2 deletions R-package/R/lgb.Dataset.R
@@ -710,7 +710,9 @@ Dataset <- R6::R6Class(
#' @title Construct \code{lgb.Dataset} object
#' @description Construct \code{lgb.Dataset} object from dense matrix, sparse matrix
#' or local file (that was created previously by saving an \code{lgb.Dataset}).
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
+#' a character representing a path to a text file (CSV, TSV, or LibSVM)
+#' or a LightGBM Dataset binary file
jameslamb marked this conversation as resolved.
Collaborator

Let's make identical descriptions truly identical, without paraphrasing. Users might otherwise spend time looking for a difference where there isn't one. It will also help us not miss any occurrences of identical parameters during possible future updates.

Suggested change
-#' a character representing a path to a text file (CSV, TSV, or LibSVM)
-#' or a LightGBM Dataset binary file
+#' a character representing a path to a text file (CSV, TSV, or LibSVM),
+#' or a character representing a path to a binary \code{Dataset} file

Collaborator Author

oh sure, seems fine to me!

I'll do this one manually instead of applying in the browser, since the corresponding lgb.Dataset.Rd will also have to be regenerated.

Collaborator Author

updated in de189a6, thanks as always for being thorough! You're right, that will make it easier to catch all such phrases in the future.
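
The concern raised in this thread, that paraphrased copies of the same parameter description are easy to miss during later updates, can be checked mechanically. Below is a minimal sketch, not part of LightGBM's tooling. It assumes the descriptions have already been extracted into a mapping from location to text; the function name `find_near_duplicates` and the 0.8 similarity threshold are hypothetical choices for illustration. The sample strings are the pre-change docstring fragments visible in this PR.

```python
from difflib import SequenceMatcher
from itertools import combinations


def find_near_duplicates(descriptions, threshold=0.8):
    """Return pairs of locations whose texts are similar but not identical.

    `descriptions` maps a location label to its description text.
    Identical texts are fine; near-identical ones are likely paraphrases.
    """
    flagged = []
    for (loc_a, text_a), (loc_b, text_b) in combinations(descriptions.items(), 2):
        if text_a == text_b:
            continue  # exact copies are what we want
        ratio = SequenceMatcher(None, text_a, text_b).ratio()
        if ratio >= threshold:
            flagged.append((loc_a, loc_b))
    return flagged


# Pre-change fragments from python-package/lightgbm/basic.py in this PR.
descriptions = {
    "Booster.predict": "it represents the path of txt file.",
    "Dataset.__init__": "it represents the path to txt file.",
    "Dataset.create_valid": "it represents the path to txt file.",
}

for loc_a, loc_b in find_near_duplicates(descriptions):
    print(f"possible paraphrase: {loc_a} vs {loc_b}")
```

The threshold is a judgment call: set it too low and descriptions that legitimately differ get flagged, too high and small paraphrases slip through.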

#' @param params a list of parameters. See
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters}{
#' The "Dataset Parameters" section of the documentation} for a list of parameters
@@ -774,7 +776,8 @@ lgb.Dataset <- function(data,
#' @title Construct validation data
#' @description Construct validation data according to training data
#' @param dataset \code{lgb.Dataset} object, training data
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
+#' a character representing a path to a text file (CSV, TSV, or LibSVM)
StrikerRUS marked this conversation as resolved.
#' @param info a list of information of the \code{lgb.Dataset} object
#' @param ... other information to pass to \code{info}.
#'
4 changes: 3 additions & 1 deletion R-package/man/lgb.Dataset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion R-package/man/lgb.Dataset.create.valid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion R-package/man/predict.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
@@ -1,5 +1,7 @@
context("Booster")

TOLERANCE <- 1e-6

test_that("Booster$finalize() should not fail", {
X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L)
y <- iris[["Sepal.Length"]]
@@ -419,6 +421,54 @@ test_that("Creating a Booster from a Dataset with an existing predictor should w
expect_equal(bst_from_ds$current_iter(), nrounds)
})

test_that("Booster$eval() should work on a Dataset stored in a binary file", {
    set.seed(708L)
    data(agaricus.train, package = "lightgbm")
    train <- agaricus.train
    dtrain <- lgb.Dataset(train$data, label = train$label)

    bst <- lgb.train(
        params = list(
            objective = "regression"
            , metric = "l2"
            , num_leaves = 4L
        )
        , data = dtrain
        , nrounds = 2L
    )

    data(agaricus.test, package = "lightgbm")
    test <- agaricus.test
    dtest <- lgb.Dataset.create.valid(
        dataset = dtrain
        , data = test$data
        , label = test$label
    )
    dtest$construct()

    eval_in_mem <- bst$eval(
        data = dtest
        , name = "test"
    )

    test_file <- tempfile(pattern = "lgb.Dataset_")
    lgb.Dataset.save(
        dataset = dtest
        , fname = test_file
    )
    rm(dtest)

    eval_from_file <- bst$eval(
        data = lgb.Dataset(
            data = test_file
        )$construct()
        , name = "test"
    )

    expect_true(abs(eval_in_mem[[1L]][["value"]] - 0.1744423) < TOLERANCE)
    expect_identical(eval_in_mem, eval_from_file)
})

test_that("Booster$rollback_one_iter() should work as expected", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
10 changes: 5 additions & 5 deletions python-package/lightgbm/basic.py
@@ -744,7 +744,7 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
-When data type is string or pathlib.Path, it represents the path of txt file.
+When data type is string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
num_iteration : int, optional (default=-1)
@@ -1132,7 +1132,7 @@ def __init__(self, data, label=None, reference=None,
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
Data source of Dataset.
-If string or pathlib.Path, it represents the path to txt file.
+If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Label of the data.
reference : Dataset or None, optional (default=None)
@@ -1776,7 +1776,7 @@ def create_valid(self, data, label=None, weight=None, group=None,
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
Data source of Dataset.
-If string or pathlib.Path, it represents the path to txt file.
+If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
@@ -3405,7 +3405,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
-If string or pathlib.Path, it represents the path to txt file.
+If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
If <= 0, starts from the first iteration.
@@ -3460,7 +3460,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for refit.
-If string or pathlib.Path, it represents the path to txt file.
+If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
StrikerRUS marked this conversation as resolved.
label : list, numpy 1-D array or pandas Series / one-column DataFrame
Label for refit.
decay_rate : float, optional (default=0.9)
2 changes: 1 addition & 1 deletion src/io/parser.cpp
@@ -235,7 +235,7 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
int num_col = 0;
DataType type = GetDataType(filename, header, lines, &num_col);
if (type == DataType::INVALID) {
-Log::Fatal("Unknown format of training data.");
+Log::Fatal("Unknown format of training data. Only CSV, TSV, and LibSVM formats are supported.");
StrikerRUS marked this conversation as resolved.
}
std::unique_ptr<Parser> ret;
int output_label_index = -1;
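
To illustrate what the extended error message refers to, here is a rough sketch of telling the three supported text formats apart from a single sample line. It is a simplified, hypothetical heuristic written in Python for illustration; it is not a port of `GetDataType()`, whose actual logic is not shown in this diff. The names `guess_text_format` and `require_text_format` are assumptions.

```python
import re

# Hypothetical, simplified heuristic; LightGBM's real detection in
# src/io/parser.cpp is more involved and is not reproduced here.
_LIBSVM_PAIR = re.compile(r"^\d+:\S+$")


def guess_text_format(line: str) -> str:
    """Return 'libsvm', 'tsv', 'csv', or 'invalid' for one sample line."""
    line = line.strip()
    if not line:
        return "invalid"
    tokens = line.split()
    # LibSVM: an optional leading label followed by index:value pairs.
    pairs = tokens[1:] if ":" not in tokens[0] else tokens
    if pairs and all(_LIBSVM_PAIR.match(tok) for tok in pairs):
        return "libsvm"
    if "\t" in line:
        return "tsv"
    if "," in line:
        return "csv"
    return "invalid"


def require_text_format(line: str) -> str:
    """Mirror the diff's behavior: fail loudly on an unrecognized format."""
    fmt = guess_text_format(line)
    if fmt == "invalid":
        raise ValueError(
            "Unknown format of training data. "
            "Only CSV, TSV, and LibSVM formats are supported."
        )
    return fmt


print(guess_text_format("1 0:0.5 3:1.2"))   # libsvm
print(guess_text_format("1.0\t2.0\t3.0"))   # tsv
print(guess_text_format("1.0,2.0,3.0"))     # csv
```

A binary Dataset file would not match any of these patterns, which is why naming the supported formats in the error message gives the user something actionable.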