From 417ba1921797d3dc5787ae8d297dcbbfc7bd7184 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 25 Aug 2021 03:33:13 +0100 Subject: [PATCH] [docs] Clarify the fact that predict() on a file does not support saved Datasets (fixes #4034) (#4545) * documentation changes * add list of supported formats to error message * add unit tests * Apply suggestions from code review Co-authored-by: Nikita Titov * update per review comments * make references consistent Co-authored-by: Nikita Titov --- R-package/R/lgb.Booster.R | 3 +- R-package/R/lgb.Dataset.R | 8 +++- R-package/man/lgb.Dataset.Rd | 4 +- R-package/man/lgb.Dataset.create.valid.Rd | 4 +- R-package/man/predict.lgb.Booster.Rd | 3 +- R-package/tests/testthat/test_lgb.Booster.R | 50 +++++++++++++++++++++ docs/Python-Intro.rst | 2 +- python-package/lightgbm/basic.py | 10 ++--- src/io/parser.cpp | 2 +- 9 files changed, 73 insertions(+), 13 deletions(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index d9e0186f97b1..507e5e01085b 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -682,7 +682,8 @@ Booster <- R6::R6Class( #' @title Predict method for LightGBM model #' @description Predicted values based on class \code{lgb.Booster} #' @param object Object of class \code{lgb.Booster} -#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename +#' @param data a \code{matrix} object, a \code{dgCMatrix} object or +#' a character representing a path to a text file (CSV, TSV, or LibSVM) #' @param start_iteration int or None, optional (default=None) #' Start index of the iteration to predict. #' If None or <= 0, starts from the first iteration. diff --git a/R-package/R/lgb.Dataset.R b/R-package/R/lgb.Dataset.R index def2d2ebecf1..e3081e7de0d6 100644 --- a/R-package/R/lgb.Dataset.R +++ b/R-package/R/lgb.Dataset.R @@ -710,7 +710,9 @@ Dataset <- R6::R6Class( #' @title Construct \code{lgb.Dataset} object #' @description Construct \code{lgb.Dataset} object from dense matrix, sparse matrix #' or local file (that was created previously by saving an \code{lgb.Dataset}). -#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename +#' @param data a \code{matrix} object, a \code{dgCMatrix} object, +#' a character representing a path to a text file (CSV, TSV, or LibSVM), +#' or a character representing a path to a binary \code{lgb.Dataset} file #' @param params a list of parameters. See #' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters}{ #' The "Dataset Parameters" section of the documentation} for a list of parameters @@ -774,7 +776,9 @@ lgb.Dataset <- function(data, #' @title Construct validation data #' @description Construct validation data according to training data #' @param dataset \code{lgb.Dataset} object, training data -#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename +#' @param data a \code{matrix} object, a \code{dgCMatrix} object, +#' a character representing a path to a text file (CSV, TSV, or LibSVM), +#' or a character representing a path to a binary \code{Dataset} file #' @param info a list of information of the \code{lgb.Dataset} object #' @param ... other information to pass to \code{info}. #' diff --git a/R-package/man/lgb.Dataset.Rd b/R-package/man/lgb.Dataset.Rd index 4a5abcf78f2c..cb71120142d3 100644 --- a/R-package/man/lgb.Dataset.Rd +++ b/R-package/man/lgb.Dataset.Rd @@ -16,7 +16,9 @@ lgb.Dataset( ) } \arguments{ -\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename} +\item{data}{a \code{matrix} object, a \code{dgCMatrix} object, +a character representing a path to a text file (CSV, TSV, or LibSVM), +or a character representing a path to a binary \code{lgb.Dataset} file} \item{params}{a list of parameters. See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters}{ diff --git a/R-package/man/lgb.Dataset.create.valid.Rd b/R-package/man/lgb.Dataset.create.valid.Rd index ce34908e1828..d0fe428d6b18 100644 --- a/R-package/man/lgb.Dataset.create.valid.Rd +++ b/R-package/man/lgb.Dataset.create.valid.Rd @@ -9,7 +9,9 @@ lgb.Dataset.create.valid(dataset, data, info = list(), ...) \arguments{ \item{dataset}{\code{lgb.Dataset} object, training data} -\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename} +\item{data}{a \code{matrix} object, a \code{dgCMatrix} object, +a character representing a path to a text file (CSV, TSV, or LibSVM), +or a character representing a path to a binary \code{Dataset} file} \item{info}{a list of information of the \code{lgb.Dataset} object} diff --git a/R-package/man/predict.lgb.Booster.Rd b/R-package/man/predict.lgb.Booster.Rd index c1c4cfb0cc77..359eb1c80a0a 100644 --- a/R-package/man/predict.lgb.Booster.Rd +++ b/R-package/man/predict.lgb.Booster.Rd @@ -20,7 +20,8 @@ \arguments{ \item{object}{Object of class \code{lgb.Booster}} -\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename} +\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or +a character representing a path to a text file (CSV, TSV, or LibSVM)} \item{start_iteration}{int or None, optional (default=None) Start index of the iteration to predict. diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 735f2fef9b66..76d3a41c9b5b 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -1,5 +1,7 @@ context("Booster") +TOLERANCE <- 1e-6 + test_that("Booster$finalize() should not fail", { X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L) y <- iris[["Sepal.Length"]] @@ -419,6 +421,54 @@ test_that("Creating a Booster from a Dataset with an existing predictor should w expect_equal(bst_from_ds$current_iter(), nrounds) }) +test_that("Booster$eval() should work on a Dataset stored in a binary file", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label = train$label) + + bst <- lgb.train( + params = list( + objective = "regression" + , metric = "l2" + , num_leaves = 4L + ) + , data = dtrain + , nrounds = 2L + ) + + data(agaricus.test, package = "lightgbm") + test <- agaricus.test + dtest <- lgb.Dataset.create.valid( + dataset = dtrain + , data = test$data + , label = test$label + ) + dtest$construct() + + eval_in_mem <- bst$eval( + data = dtest + , name = "test" + ) + + test_file <- tempfile(pattern = "lgb.Dataset_") + lgb.Dataset.save( + dataset = dtest + , fname = test_file + ) + rm(dtest) + + eval_from_file <- bst$eval( + data = lgb.Dataset( + data = test_file + )$construct() + , name = "test" + ) + + expect_true(abs(eval_in_mem[[1L]][["value"]] - 0.1744423) < TOLERANCE) + expect_identical(eval_in_mem, eval_from_file) +}) + test_that("Booster$rollback_one_iter() should work as expected", { set.seed(708L) data(agaricus.train, package = "lightgbm") diff --git a/docs/Python-Intro.rst b/docs/Python-Intro.rst index 063dbf172445..090bbc1c3b54 100644 --- a/docs/Python-Intro.rst +++ b/docs/Python-Intro.rst @@ -33,7 +33,7 @@ Data Interface The LightGBM Python module can load data from: -- LibSVM (zero-based) / TSV / CSV / TXT format file +- LibSVM (zero-based) / TSV / CSV format text file - NumPy 2D array(s), pandas DataFrame, H2O DataTable's Frame, SciPy sparse matrix diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index d33ba3fd6ebb..1490fbb7f4db 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -744,7 +744,7 @@ def predict(self, data, start_iteration=0, num_iteration=-1, ---------- data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse Data source for prediction. - When data type is string or pathlib.Path, it represents the path of txt file. + When data type is string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). start_iteration : int, optional (default=0) Start index of the iteration to predict. num_iteration : int, optional (default=-1) @@ -1132,7 +1132,7 @@ def __init__(self, data, label=None, reference=None, ---------- data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays Data source of Dataset. - If string or pathlib.Path, it represents the path to txt file. + If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) Label of the data. reference : Dataset or None, optional (default=None) @@ -1776,7 +1776,7 @@ def create_valid(self, data, label=None, weight=None, group=None, ---------- data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays Data source of Dataset. - If string or pathlib.Path, it represents the path to txt file. + If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) Label of the data. weight : list, numpy 1-D array, pandas Series or None, optional (default=None) @@ -3414,7 +3414,7 @@ def predict(self, data, start_iteration=0, num_iteration=None, ---------- data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse Data source for prediction. - If string or pathlib.Path, it represents the path to txt file. + If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). start_iteration : int, optional (default=0) Start index of the iteration to predict. If <= 0, starts from the first iteration. @@ -3469,7 +3469,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs): ---------- data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse Data source for refit. - If string or pathlib.Path, it represents the path to txt file. + If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). label : list, numpy 1-D array or pandas Series / one-column DataFrame Label for refit. decay_rate : float, optional (default=0.9) diff --git a/src/io/parser.cpp b/src/io/parser.cpp index 2dd46adbf937..58f2d5b94467 100644 --- a/src/io/parser.cpp +++ b/src/io/parser.cpp @@ -236,7 +236,7 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features int num_col = 0; DataType type = GetDataType(filename, header, lines, &num_col); if (type == DataType::INVALID) { - Log::Fatal("Unknown format of training data."); + Log::Fatal("Unknown format of training data. Only CSV, TSV, and LibSVM (zero-based) formatted text files are supported."); } std::unique_ptr ret; int output_label_index = -1;