
Commit

[docs] Clarify the fact that predict() on a file does not support saved Datasets (fixes #4034) (#4545)

* documentation changes

* add list of supported formats to error message

* add unit tests

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* update per review comments

* make references consistent

Co-authored-by: Nikita Titov <[email protected]>
jameslamb and StrikerRUS authored Aug 25, 2021
1 parent 11d7608 commit 417ba19
Showing 9 changed files with 73 additions and 13 deletions.
3 changes: 2 additions & 1 deletion R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
@@ -682,7 +682,8 @@ Booster <- R6::R6Class(
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
+#' a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration int or None, optional (default=None)
#' Start index of the iteration to predict.
#' If None or <= 0, starts from the first iteration.
8 changes: 6 additions & 2 deletions R-package/R/lgb.Dataset.R
@@ -710,7 +710,9 @@ Dataset <- R6::R6Class(
#' @title Construct \code{lgb.Dataset} object
#' @description Construct \code{lgb.Dataset} object from dense matrix, sparse matrix
#' or local file (that was created previously by saving an \code{lgb.Dataset}).
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object,
+#' a character representing a path to a text file (CSV, TSV, or LibSVM),
+#' or a character representing a path to a binary \code{lgb.Dataset} file
#' @param params a list of parameters. See
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters}{
#' The "Dataset Parameters" section of the documentation} for a list of parameters
@@ -774,7 +776,9 @@ lgb.Dataset <- function(data,
#' @title Construct validation data
#' @description Construct validation data according to training data
#' @param dataset \code{lgb.Dataset} object, training data
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object,
+#' a character representing a path to a text file (CSV, TSV, or LibSVM),
+#' or a character representing a path to a binary \code{Dataset} file
#' @param info a list of information of the \code{lgb.Dataset} object
#' @param ... other information to pass to \code{info}.
#'
4 changes: 3 additions & 1 deletion R-package/man/lgb.Dataset.Rd

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion R-package/man/lgb.Dataset.create.valid.Rd


3 changes: 2 additions & 1 deletion R-package/man/predict.lgb.Booster.Rd


50 changes: 50 additions & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
@@ -1,5 +1,7 @@
context("Booster")

TOLERANCE <- 1e-6

test_that("Booster$finalize() should not fail", {
X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L)
y <- iris[["Sepal.Length"]]
@@ -419,6 +421,54 @@ test_that("Creating a Booster from a Dataset with an existing predictor should w
expect_equal(bst_from_ds$current_iter(), nrounds)
})

test_that("Booster$eval() should work on a Dataset stored in a binary file", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)

bst <- lgb.train(
params = list(
objective = "regression"
, metric = "l2"
, num_leaves = 4L
)
, data = dtrain
, nrounds = 2L
)

data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(
dataset = dtrain
, data = test$data
, label = test$label
)
dtest$construct()

eval_in_mem <- bst$eval(
data = dtest
, name = "test"
)

test_file <- tempfile(pattern = "lgb.Dataset_")
lgb.Dataset.save(
dataset = dtest
, fname = test_file
)
rm(dtest)

eval_from_file <- bst$eval(
data = lgb.Dataset(
data = test_file
)$construct()
, name = "test"
)

expect_true(abs(eval_in_mem[[1L]][["value"]] - 0.1744423) < TOLERANCE)
expect_identical(eval_in_mem, eval_from_file)
})

test_that("Booster$rollback_one_iter() should work as expected", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
2 changes: 1 addition & 1 deletion docs/Python-Intro.rst
@@ -33,7 +33,7 @@ Data Interface

The LightGBM Python module can load data from:

-- LibSVM (zero-based) / TSV / CSV / TXT format file
+- LibSVM (zero-based) / TSV / CSV format text file

- NumPy 2D array(s), pandas DataFrame, H2O DataTable's Frame, SciPy sparse matrix
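For reference, the three supported text formats look like this for the same two rows (label first, no header). This is a sketch using only the standard library, with invented file names:

```python
import csv

# Two rows; the first column is the label, the rest are feature values.
rows = [(1.0, 0.5, 0.0, 2.5), (0.0, 1.5, 3.0, 0.0)]

# CSV: comma-separated values
with open("data.csv", "w", newline="") as f:
    csv.writer(f).writerows(rows)

# TSV: the same content, tab-separated
with open("data.tsv", "w", newline="") as f:
    csv.writer(f, delimiter="\t").writerows(rows)

# LibSVM (zero-based): "<label> <index>:<value> ...", zero entries omitted
with open("data.libsvm", "w") as f:
    for label, *feats in rows:
        pairs = " ".join(f"{i}:{v}" for i, v in enumerate(feats) if v != 0.0)
        f.write(f"{label} {pairs}".rstrip() + "\n")
```

All three files describe the same dense matrix; LibSVM is the only sparse representation of the three.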

10 changes: 5 additions & 5 deletions python-package/lightgbm/basic.py
@@ -744,7 +744,7 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
-When data type is string or pathlib.Path, it represents the path of txt file.
+When data type is string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
num_iteration : int, optional (default=-1)
@@ -1132,7 +1132,7 @@ def __init__(self, data, label=None, reference=None,
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
Data source of Dataset.
-If string or pathlib.Path, it represents the path to txt file.
+If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Label of the data.
reference : Dataset or None, optional (default=None)
@@ -1776,7 +1776,7 @@ def create_valid(self, data, label=None, weight=None, group=None,
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
Data source of Dataset.
-If string or pathlib.Path, it represents the path to txt file.
+If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
@@ -3414,7 +3414,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
-If string or pathlib.Path, it represents the path to txt file.
+If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
If <= 0, starts from the first iteration.
@@ -3469,7 +3469,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
----------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for refit.
-If string or pathlib.Path, it represents the path to txt file.
+If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
label : list, numpy 1-D array or pandas Series / one-column DataFrame
Label for refit.
decay_rate : float, optional (default=0.9)
2 changes: 1 addition & 1 deletion src/io/parser.cpp
@@ -236,7 +236,7 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
int num_col = 0;
DataType type = GetDataType(filename, header, lines, &num_col);
if (type == DataType::INVALID) {
-Log::Fatal("Unknown format of training data.");
+Log::Fatal("Unknown format of training data. Only CSV, TSV, and LibSVM (zero-based) formatted text files are supported.");
}
std::unique_ptr<Parser> ret;
int output_label_index = -1;
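`CreateParser()` picks a parser by sniffing the file content via `GetDataType()`. A loose Python sketch of that kind of detection (illustrative only, not the actual C++ logic): LibSVM rows are recognizable by their `index:value` pairs, and otherwise the delimiter decides between CSV and TSV.

```python
def sniff_format(line: str) -> str:
    """Guess the text format of a data row: 'libsvm', 'tsv', 'csv', or 'invalid'."""
    tokens = line.strip().split()
    # LibSVM: "<label> <index>:<value> ..." -- every token after the label has a colon
    if len(tokens) > 1 and all(":" in t for t in tokens[1:]):
        return "libsvm"
    if "\t" in line:
        return "tsv"
    if "," in line:
        return "csv"
    return "invalid"
```

A row that matches none of the three patterns is the case that now produces the more helpful fatal error above.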
