From b7f6311f275eb62989b3977143b0e9335e252202 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Tue, 7 Nov 2023 19:14:09 +0100 Subject: [PATCH 1/4] [python-package] Allow to pass Arrow array as labels (#6163) --- include/LightGBM/c_api.h | 17 +++++++++ include/LightGBM/dataset.h | 6 ++++ python-package/lightgbm/basic.py | 45 +++++++++++++++++++----- python-package/lightgbm/compat.py | 14 ++++++++ src/c_api.cpp | 16 +++++++++ src/io/dataset.cpp | 11 ++++++ src/io/metadata.cpp | 28 ++++++++++----- tests/python_package_test/test_arrow.py | 46 +++++++++++++++++++++++++ 8 files changed, 166 insertions(+), 17 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 01f55d5ddfb4..a46f8332811a 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -555,6 +555,23 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, int num_element, int type); +/*! + * \brief Set vector to a content in info. + * \note + * - \a label convert input datatype into ``float32``. + * \param handle Handle of dataset + * \param field_name Field name, can be \a label + * \param n_chunks The number of Arrow arrays passed to this function + * \param chunks Pointer to the list of Arrow arrays + * \param schema Pointer to the schema of all Arrow arrays + * \return 0 when succeed, -1 when failure happens + */ +LIGHTGBM_C_EXPORT int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle, + const char* field_name, + int64_t n_chunks, + const ArrowArray* chunks, + const ArrowSchema* schema); + /*! * \brief Get info vector from dataset. * \param handle Handle of dataset diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index e94e0d943a72..56bc7b841dc3 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -110,6 +110,7 @@ class Metadata { const std::vector& used_data_indices); void SetLabel(const label_t* label, data_size_t len); + void SetLabel(const ArrowChunkedArray& array); void SetWeights(const label_t* weights, data_size_t len); @@ -334,6 +335,9 @@ class Metadata { void CalculateQueryBoundaries(); /*! \brief Insert labels at the given index */ void InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len); + /*! \brief Set labels from pointers to the first element and the end of an iterator. */ + template + void SetLabelsFromIterator(It first, It last); /*! \brief Insert weights at the given index */ void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len); /*! \brief Insert initial scores at the given index */ @@ -655,6 +659,8 @@ class Dataset { LIGHTGBM_EXPORT void FinishLoad(); + bool SetFieldFromArrow(const char* field_name, const ArrowChunkedArray& ca); + LIGHTGBM_EXPORT bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element); LIGHTGBM_EXPORT bool SetDoubleField(const char* field_name, const double* field_data, data_size_t num_element); diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index d0d9f0b136f8..702c4682ea8d 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -19,7 +19,7 @@ import scipy.sparse from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat, - dt_DataTable, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series) + dt_DataTable, pa_Array, pa_ChunkedArray, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series) from .libpath import find_lib_path if TYPE_CHECKING: @@ -99,7 +99,9 @@ List[int], np.ndarray, pd_Series, - pd_DataFrame + pd_DataFrame, + pa_Array, + pa_ChunkedArray, ] _LGBM_PredictDataType = Union[ str, @@ -353,6 +355,11 @@ def _is_2d_collection(data: Any) -> bool: ) +def _is_pyarrow_array(data: Any) -> bool: + """Check whether data is a PyArrow array.""" + return isinstance(data, (pa_Array, pa_ChunkedArray)) + + def _is_pyarrow_table(data: Any) -> bool: """Check whether data is a PyArrow table.""" return isinstance(data, pa_Table) @@ -384,7 +391,11 @@ def schema_ptr(self) -> int: def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray: """Export an Arrow type to its C representation.""" # Obtain objects to export - if isinstance(data, pa_Table): + if isinstance(data, pa_Array): + export_objects = [data] + elif isinstance(data, pa_ChunkedArray): + export_objects = data.chunks + elif isinstance(data, pa_Table): export_objects = data.to_batches() else: raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow") @@ -1620,7 +1631,7 @@ def __init__( data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table Data source of Dataset. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. - label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) + label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) Label of the data. reference : Dataset or None, optional (default=None) If this is Dataset for validation, training data should be used as reference. @@ -2402,7 +2413,7 @@ def create_valid( data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array Data source of Dataset. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. - label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) + label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) Label of the data. weight : list, numpy 1-D array, pandas Series or None, optional (default=None) Weight for each instance. Weights should be non-negative. @@ -2519,7 +2530,7 @@ def _reverse_update_params(self) -> "Dataset": def set_field( self, field_name: str, - data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame]] + data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Array, pa_ChunkedArray]] ) -> "Dataset": """Set property into the Dataset. @@ -2527,7 +2538,7 @@ def set_field( ---------- field_name : str The field name of the information. - data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None + data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray or None The data to be set. Returns @@ -2546,6 +2557,20 @@ def set_field( ctypes.c_int(0), ctypes.c_int(_FIELD_TYPE_MAPPER[field_name]))) return self + + # If the data is a arrow data, we can just pass it to C + if _is_pyarrow_array(data): + c_array = _export_arrow_to_c(data) + _safe_call(_LIB.LGBM_DatasetSetFieldFromArrow( + self._handle, + _c_str(field_name), + ctypes.c_int64(c_array.n_chunks), + ctypes.c_void_p(c_array.chunks_ptr), + ctypes.c_void_p(c_array.schema_ptr), + )) + self.version += 1 + return self + dtype: "np.typing.DTypeLike" if field_name == 'init_score': dtype = np.float64 @@ -2749,7 +2774,7 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset": Parameters ---------- - label : list, numpy 1-D array, pandas Series / one-column DataFrame or None + label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None The label information to be set into Dataset. Returns @@ -2774,6 +2799,8 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset": # data has nullable dtypes, but we can specify na_value argument and copy will be made label = label.to_numpy(dtype=np.float32, na_value=np.nan) label_array = np.ravel(label) + elif _is_pyarrow_array(label): + label_array = label else: label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label') self.set_field('label', label_array) @@ -4353,7 +4380,7 @@ def refit( data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array Data source for refit. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). - label : list, numpy 1-D array or pandas Series / one-column DataFrame + label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or pyarrow ChunkedArray Label for refit. decay_rate : float, optional (default=0.9) Decay rate of refit, diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 7be375e02e85..82d048bb374e 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -187,6 +187,8 @@ def __init__(self, *args, **kwargs): """pyarrow""" try: + from pyarrow import Array as pa_Array + from pyarrow import ChunkedArray as pa_ChunkedArray from pyarrow import Table as pa_Table from pyarrow.cffi import ffi as arrow_cffi from pyarrow.types import is_floating as arrow_is_floating @@ -195,6 +197,18 @@ def __init__(self, *args, **kwargs): except ImportError: PYARROW_INSTALLED = False + class pa_Array: # type: ignore + """Dummy class for pa.Array.""" + + def __init__(self, *args, **kwargs): + pass + + class pa_ChunkedArray: # type: ignore + """Dummy class for pa.ChunkedArray.""" + + def __init__(self, *args, **kwargs): + pass + class pa_Table: # type: ignore """Dummy class for pa.Table.""" diff --git a/src/c_api.cpp b/src/c_api.cpp index 6467bb54a8fe..baf934db42b1 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -833,6 +833,7 @@ class Booster { // explicitly declare symbols from LightGBM namespace using LightGBM::AllgatherFunction; +using LightGBM::ArrowChunkedArray; using LightGBM::ArrowTable; using LightGBM::Booster; using LightGBM::Common::CheckElementsIntervalClosed; @@ -1780,6 +1781,21 @@ int LGBM_DatasetSetField(DatasetHandle handle, API_END(); } +int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle, + const char* field_name, + int64_t n_chunks, + const ArrowArray* chunks, + const ArrowSchema* schema) { + API_BEGIN(); + auto dataset = reinterpret_cast(handle); + ArrowChunkedArray ca(n_chunks, chunks, schema); + auto is_success = dataset->SetFieldFromArrow(field_name, ca); + if (!is_success) { + Log::Fatal("Input field is not supported"); + } + API_END(); +} + int LGBM_DatasetGetField(DatasetHandle handle, const char* field_name, int* out_len, diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index 147765644887..e78f8a6b696c 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -897,6 +897,17 @@ void Dataset::CopySubrow(const Dataset* fullset, #endif // USE_CUDA } +bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray &ca) { + std::string name(field_name); + name = Common::Trim(name); + if (name == std::string("label") || name == std::string("target")) { + metadata_.SetLabel(ca); + } else { + return false; + } + return true; +} + bool Dataset::SetFloatField(const char* field_name, const float* field_data, data_size_t num_element) { std::string name(field_name); diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp index c9e8973addb4..41f9e3bf43c6 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -403,27 +403,39 @@ void Metadata::InsertInitScores(const double* init_scores, data_size_t start_ind // CUDA is handled after all insertions are complete } -void Metadata::SetLabel(const label_t* label, data_size_t len) { +template +void Metadata::SetLabelsFromIterator(It first, It last) { std::lock_guard lock(mutex_); - if (label == nullptr) { - Log::Fatal("label cannot be nullptr"); + if (num_data_ != last - first) { + Log::Fatal("Length of labels differs from the length of #data"); } - if (num_data_ != len) { - Log::Fatal("Length of label is not same with #data"); + if (label_.empty()) { + label_.resize(num_data_); } - if (label_.empty()) { label_.resize(num_data_); } #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_data_ >= 1024) for (data_size_t i = 0; i < num_data_; ++i) { - label_[i] = Common::AvoidInf(label[i]); + label_[i] = Common::AvoidInf(first[i]); } + #ifdef USE_CUDA if (cuda_metadata_ != nullptr) { - cuda_metadata_->SetLabel(label_.data(), len); + cuda_metadata_->SetLabel(label_.data(), label_.size()); } #endif // USE_CUDA } +void Metadata::SetLabel(const label_t* label, data_size_t len) { + if (label == nullptr) { + Log::Fatal("label cannot be nullptr"); + } + SetLabelsFromIterator(label, label + len); +} + +void Metadata::SetLabel(const ArrowChunkedArray& array) { + SetLabelsFromIterator(array.begin(), array.end()); +} + void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len) { if (labels == nullptr) { Log::Fatal("label cannot be nullptr"); diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py index 54ca945e1e53..1dd270c8ec53 100644 --- a/tests/python_package_test/test_arrow.py +++ b/tests/python_package_test/test_arrow.py @@ -67,6 +67,10 @@ def dummy_dataset_params() -> Dict[str, Any]: } +def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray): + assert lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs) + + # ----------------------------------------------------------------------------------------------- # # UNIT TESTS # # ----------------------------------------------------------------------------------------------- # @@ -97,3 +101,45 @@ def test_dataset_construct_fuzzy( arrow_dataset._dump_text(tmp_path / "arrow.txt") pandas_dataset._dump_text(tmp_path / "pandas.txt") assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt") + + +@pytest.mark.parametrize( + ["array_type", "label_data"], + [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])], +) +@pytest.mark.parametrize( + "arrow_type", + [ + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64(), + pa.uint8(), + pa.uint16(), + pa.uint32(), + pa.uint64(), + pa.float32(), + pa.float64(), + ], +) +def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any): + data = generate_dummy_arrow_table() + labels = array_type(label_data, type=arrow_type) + dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params()) + dataset.construct() + + expected = np.array([0, 1, 0, 0, 1], dtype=np.float32) + assert_arrays_equal(expected, dataset.get_label()) + + +def test_dataset_construct_labels_fuzzy(): + arrow_table = generate_random_arrow_table(3, 1000, 42) + arrow_array = generate_random_arrow_array(1000, 42) + + arrow_dataset = lgb.Dataset(arrow_table, label=arrow_array) + arrow_dataset.construct() + + pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), label=arrow_array.to_numpy()) + pandas_dataset.construct() + + assert_arrays_equal(arrow_dataset.get_label(), pandas_dataset.get_label()) From aeafccfbfb5c223d33b61ebe0f1e8b5592249151 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 7 Nov 2023 15:01:52 -0600 Subject: [PATCH 2/4] [python-package] fix access to Dataset metadata in scikit-learn custom metrics and objectives (#6108) --- python-package/lightgbm/basic.py | 68 +++++++++++++------ python-package/lightgbm/sklearn.py | 69 ++++++++++++++----- tests/python_package_test/test_basic.py | 90 ++++++++++++++++++++++++- tests/python_package_test/utils.py | 20 ++++++ 4 files changed, 209 insertions(+), 38 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 702c4682ea8d..e8d8bd84cbe7 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -434,7 +434,7 @@ def _data_to_2d_numpy( "It should be list of lists, numpy 2-D array or pandas DataFrame") -def _cfloat32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes float pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_float)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -442,7 +442,7 @@ def _cfloat32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray raise RuntimeError('Expected float pointer') -def _cfloat64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes double pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_double)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -450,7 +450,7 @@ def _cfloat64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray raise RuntimeError('Expected double pointer') -def _cint32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes int pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -458,7 +458,7 @@ def _cint32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: raise RuntimeError('Expected int32 pointer') -def _cint64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes int pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -1295,18 +1295,18 @@ def __create_sparse_native( data_indices_len = out_shape[0] indptr_len = out_shape[1] if indptr_type == _C_API_DTYPE_INT32: - out_indptr = _cint32_array_to_numpy(out_ptr_indptr, indptr_len) + out_indptr = _cint32_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len) elif indptr_type == _C_API_DTYPE_INT64: - out_indptr = _cint64_array_to_numpy(out_ptr_indptr, indptr_len) + out_indptr = _cint64_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len) else: raise TypeError("Expected int32 or int64 type for indptr") if data_type == _C_API_DTYPE_FLOAT32: - out_data = _cfloat32_array_to_numpy(out_ptr_data, data_indices_len) + out_data = _cfloat32_array_to_numpy(cptr=out_ptr_data, length=data_indices_len) elif data_type == _C_API_DTYPE_FLOAT64: - out_data = _cfloat64_array_to_numpy(out_ptr_data, data_indices_len) + out_data = _cfloat64_array_to_numpy(cptr=out_ptr_data, length=data_indices_len) else: raise TypeError("Expected float32 or float64 type for data") - out_indices = _cint32_array_to_numpy(out_ptr_indices, data_indices_len) + out_indices = _cint32_array_to_numpy(cptr=out_ptr_indices, length=data_indices_len) # break up indptr based on number of rows (note more than one matrix in multiclass case) per_class_indptr_shape = cs.indptr.shape[0] # for CSC there is extra column added @@ -2609,6 +2609,12 @@ def set_field( def get_field(self, field_name: str) -> Optional[np.ndarray]: """Get property from the Dataset. + Can only be run on a constructed Dataset. + + Unlike ``get_group()``, ``get_init_score()``, ``get_label()``, ``get_position()``, and ``get_weight()``, + this method ignores any raw data passed into ``lgb.Dataset()`` on the Python side, and will only read + data from the constructed C++ ``Dataset`` object. + Parameters ---------- field_name : str @@ -2635,11 +2641,20 @@ def get_field(self, field_name: str) -> Optional[np.ndarray]: if tmp_out_len.value == 0: return None if out_type.value == _C_API_DTYPE_INT32: - arr = _cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), tmp_out_len.value) + arr = _cint32_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), + length=tmp_out_len.value + ) elif out_type.value == _C_API_DTYPE_FLOAT32: - arr = _cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value) + arr = _cfloat32_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), + length=tmp_out_len.value + ) elif out_type.value == _C_API_DTYPE_FLOAT64: - arr = _cfloat64_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), tmp_out_len.value) + arr = _cfloat64_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), + length=tmp_out_len.value + ) else: raise TypeError("Unknown type") if field_name == 'init_score': @@ -2878,6 +2893,10 @@ def set_group( if self._handle is not None and group is not None: group = _list_to_1d_numpy(group, dtype=np.int32, name='group') self.set_field('group', group) + # original values can be modified at cpp side + constructed_group = self.get_field('group') + if constructed_group is not None: + self.group = np.diff(constructed_group) return self def set_position( @@ -2941,37 +2960,40 @@ def get_feature_name(self) -> List[str]: ptr_string_buffers)) return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)] - def get_label(self) -> Optional[np.ndarray]: + def get_label(self) -> Optional[_LGBM_LabelType]: """Get the label of the Dataset. Returns ------- - label : numpy array or None + label : list, numpy 1-D array, pandas Series / one-column DataFrame or None The label information from the Dataset. + For a constructed ``Dataset``, this will only return a numpy array. """ if self.label is None: self.label = self.get_field('label') return self.label - def get_weight(self) -> Optional[np.ndarray]: + def get_weight(self) -> Optional[_LGBM_WeightType]: """Get the weight of the Dataset. Returns ------- - weight : numpy array or None + weight : list, numpy 1-D array, pandas Series or None Weight for each data point from the Dataset. Weights should be non-negative. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.weight is None: self.weight = self.get_field('weight') return self.weight - def get_init_score(self) -> Optional[np.ndarray]: + def get_init_score(self) -> Optional[_LGBM_InitScoreType]: """Get the initial score of the Dataset. Returns ------- - init_score : numpy array or None + init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None Init score of Booster. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.init_score is None: self.init_score = self.get_field('init_score') @@ -3009,17 +3031,18 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]: "set free_raw_data=False when construct Dataset to avoid this.") return self.data - def get_group(self) -> Optional[np.ndarray]: + def get_group(self) -> Optional[_LGBM_GroupType]: """Get the group of the Dataset. Returns ------- - group : numpy array or None + group : list, numpy 1-D array, pandas Series or None Group/query data. Only used in the learning-to-rank task. sum(group) = n_samples. For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.group is None: self.group = self.get_field('group') @@ -3028,13 +3051,14 @@ def get_group(self) -> Optional[np.ndarray]: self.group = np.diff(self.group) return self.group - def get_position(self) -> Optional[np.ndarray]: + def get_position(self) -> Optional[_LGBM_PositionType]: """Get the position of the Dataset. Returns ------- - position : numpy 1-D array or None + position : numpy 1-D array, pandas Series or None Position of items used in unbiased learning-to-rank task. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.position is None: self.position = self.get_field('position') diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index c71c233df908..310d5d2ca6ea 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -86,6 +86,36 @@ _LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType] +def _get_group_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]: + group = dataset.get_group() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve query groups from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert (group is None or isinstance(group, np.ndarray)), error_msg + return group + + +def _get_label_from_constructed_dataset(dataset: Dataset) -> np.ndarray: + label = dataset.get_label() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve labels from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert isinstance(label, np.ndarray), error_msg + return label + + +def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]: + weight = dataset.get_weight() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve weights from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert (weight is None or isinstance(weight, np.ndarray)), error_msg + return weight + + class _ObjectiveFunctionWrapper: """Proxy class for objective function.""" @@ -151,17 +181,22 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. The value of the second order derivative (Hessian) of the loss with respect to the elements of preds for each sample point. """ - labels = dataset.get_label() + labels = _get_label_from_constructed_dataset(dataset) argc = len(signature(self.func).parameters) if argc == 2: grad, hess = self.func(labels, preds) # type: ignore[call-arg] - elif argc == 3: - grad, hess = self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg] - elif argc == 4: - grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore [call-arg] - else: - raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") - return grad, hess + return grad, hess + + weight = _get_weight_from_constructed_dataset(dataset) + if argc == 3: + grad, hess = self.func(labels, preds, weight) # type: ignore[call-arg] + return grad, hess + + if argc == 4: + group = _get_group_from_constructed_dataset(dataset) + return self.func(labels, preds, weight, group) # type: ignore[call-arg] + + raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") class _EvalFunctionWrapper: @@ -229,16 +264,20 @@ def __call__( is_higher_better : bool Is eval result higher better, e.g. AUC is ``is_higher_better``. """ - labels = dataset.get_label() + labels = _get_label_from_constructed_dataset(dataset) argc = len(signature(self.func).parameters) if argc == 2: return self.func(labels, preds) # type: ignore[call-arg] - elif argc == 3: - return self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg] - elif argc == 4: - return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore[call-arg] - else: - raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") + + weight = _get_weight_from_constructed_dataset(dataset) + if argc == 3: + return self.func(labels, preds, weight) # type: ignore[call-arg] + + if argc == 4: + group = _get_group_from_constructed_dataset(dataset) + return self.func(labels, preds, weight, group) # type: ignore[call-arg] + + raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") # documentation templates for LGBMModel methods are shared between the classes in diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 7f8980c271f7..2f6b07e7a77f 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -15,7 +15,7 @@ import lightgbm as lgb from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series -from .utils import dummy_obj, load_breast_cancer, mse_obj +from .utils import dummy_obj, load_breast_cancer, mse_obj, np_assert_array_equal def test_basic(tmp_path): @@ -499,6 +499,94 @@ def check_asserts(data): check_asserts(lgb_data) +def test_dataset_construction_overwrites_user_provided_metadata_fields(): + + X = np.array([[1.0, 2.0], [3.0, 4.0]]) + + position = np.array([0.0, 1.0], dtype=np.float32) + if getenv('TASK', '') == 'cuda': + position = None + + dtrain = lgb.Dataset( + X, + params={ + "min_data_in_bin": 1, + "min_data_in_leaf": 1, + "verbosity": -1 + }, + group=[1, 1], + init_score=[0.312, 0.708], + label=[1, 2], + position=position, + weight=[0.5, 1.5], + ) + + # unconstructed, get_* methods should return whatever was provided + assert dtrain.group == [1, 1] + assert dtrain.get_group() == [1, 1] + assert dtrain.init_score == [0.312, 0.708] + assert dtrain.get_init_score() == [0.312, 0.708] + assert dtrain.label == [1, 2] + assert dtrain.get_label() == [1, 2] + if getenv('TASK', '') != 'cuda': + np_assert_array_equal( + dtrain.position, + np.array([0.0, 1.0], dtype=np.float32), + strict=True + ) + np_assert_array_equal( + dtrain.get_position(), + np.array([0.0, 1.0], dtype=np.float32), + strict=True + ) + assert dtrain.weight == [0.5, 1.5] + assert dtrain.get_weight() == [0.5, 1.5] + + # before construction, get_field() should raise an exception + for field_name in ["group", "init_score", "label", "position", "weight"]: + with pytest.raises(Exception, match=f"Cannot get {field_name} before construct Dataset"): + dtrain.get_field(field_name) + + # constructed, get_* methods should return numpy arrays, even when the provided + # input was a list of floats or ints + dtrain.construct() + expected_group = np.array([1, 1], dtype=np.int32) + np_assert_array_equal(dtrain.group, expected_group, strict=True) + np_assert_array_equal(dtrain.get_group(), expected_group, strict=True) + # get_field("group") returns a numpy array with boundaries, instead of size + np_assert_array_equal( + dtrain.get_field("group"), + np.array([0, 1, 2], dtype=np.int32), + strict=True + ) + + expected_init_score = np.array([0.312, 0.708],) + np_assert_array_equal(dtrain.init_score, expected_init_score, strict=True) + np_assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True) + np_assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True) + + expected_label = np.array([1, 2], dtype=np.float32) + np_assert_array_equal(dtrain.label, expected_label, strict=True) + np_assert_array_equal(dtrain.get_label(), expected_label, strict=True) + np_assert_array_equal(dtrain.get_field("label"), expected_label, strict=True) + + if getenv('TASK', '') != 'cuda': + expected_position = np.array([0.0, 1.0], dtype=np.float32) + np_assert_array_equal(dtrain.position, expected_position, strict=True) + np_assert_array_equal(dtrain.get_position(), expected_position, strict=True) + # NOTE: "position" is converted to int32 on the C++ side + np_assert_array_equal( + dtrain.get_field("position"), + np.array([0.0, 1.0], dtype=np.int32), + strict=True + ) + + expected_weight = np.array([0.5, 1.5], dtype=np.float32) + np_assert_array_equal(dtrain.weight, expected_weight, strict=True) + np_assert_array_equal(dtrain.get_weight(), expected_weight, strict=True) + np_assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True) + + def test_choose_param_value(): original_params = { diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py index df01e29852e7..7eae62b14369 100644 --- a/tests/python_package_test/utils.py +++ b/tests/python_package_test/utils.py @@ -1,6 +1,7 @@ # coding: utf-8 import pickle from functools import lru_cache +from inspect import getfullargspec import cloudpickle import joblib @@ -193,3 +194,22 @@ def pickle_and_unpickle_object(obj, serializer): serializer=serializer ) return obj_from_disk # noqa: RET504 + + +# doing this here, at import time, to ensure it only runs once_per import +# instead of once per assertion +_numpy_testing_supports_strict_kwarg = ( + "strict" in getfullargspec(np.testing.assert_array_equal).kwonlyargs +) + + +def np_assert_array_equal(*args, **kwargs): + """ + np.testing.assert_array_equal() only got the kwarg ``strict`` in June 2022: + https://github.com/numpy/numpy/pull/21595 + + This function is here for testing on older Python (and therefore ``numpy``) + """ + if not _numpy_testing_supports_strict_kwarg: + kwargs.pop("strict") + np.testing.assert_array_equal(*args, **kwargs) From 5e90255ee78bbee07d3a1afb01ffa22dfcfe9b6f Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 8 Nov 2023 13:22:37 -0600 Subject: [PATCH 3/4] [R-package] remove unreachable code (#6180) --- R-package/R/lgb.Booster.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 755b171724f9..2256a250b131 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -1462,7 +1462,6 @@ lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_ , toString(eval_names) , "]" )) - stop("lgb.get.eval.result: wrong eval name") } result <- booster$record_evals[[data_name]][[eval_name]][[.EVAL_KEY()]] From 501e6e62fe0d505edb2578e8826e3da85a775aa8 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Thu, 9 Nov 2023 05:56:08 +0100 Subject: [PATCH 4/4] [python-package] Accept numpy generators as `random_state` (#6174) --- python-package/lightgbm/compat.py | 10 ++++++++++ python-package/lightgbm/dask.py | 6 +++--- python-package/lightgbm/sklearn.py | 10 +++++++--- tests/python_package_test/test_sklearn.py | 7 ++++--- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 82d048bb374e..984972ed1ae3 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -36,6 +36,16 @@ def __init__(self, *args, **kwargs): concat = None +"""numpy""" +try: + from numpy.random import Generator as np_random_Generator +except ImportError: + class np_random_Generator: # type: ignore + """Dummy class for np.random.Generator.""" + + def __init__(self, *args, **kwargs): + pass + """matplotlib""" try: import matplotlib # noqa: F401 diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 8aeeac09eed2..88e4779ee8ee 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -1142,7 +1142,7 @@ def __init__( colsample_bytree: float = 1., reg_alpha: float = 0., reg_lambda: float = 0., - random_state: Optional[Union[int, np.random.RandomState]] = None, + random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None, n_jobs: Optional[int] = None, importance_type: str = 'split', client: Optional[Client] = None, @@ -1347,7 +1347,7 @@ def __init__( colsample_bytree: float = 1., reg_alpha: float = 0., reg_lambda: float = 0., - random_state: Optional[Union[int, np.random.RandomState]] = None, + random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None, n_jobs: Optional[int] = None, importance_type: str = 'split', client: Optional[Client] = None, @@ -1517,7 +1517,7 @@ def __init__( colsample_bytree: float = 1., reg_alpha: float = 0., reg_lambda: float = 0., - random_state: Optional[Union[int, np.random.RandomState]] = None, + random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None, n_jobs: Optional[int] = None, importance_type: str = 'split', client: Optional[Client] = None, diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 310d5d2ca6ea..120a6667192c 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -15,7 +15,7 @@ from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray, _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase, _LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, - dt_DataTable, pd_DataFrame) + dt_DataTable, np_random_Generator, pd_DataFrame) from .engine import train __all__ = [ @@ -448,7 +448,7 @@ def __init__( colsample_bytree: float = 1., reg_alpha: float = 0., reg_lambda: float = 0., - random_state: Optional[Union[int, np.random.RandomState]] = None, + random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None, n_jobs: Optional[int] = None, importance_type: str = 'split', **kwargs @@ -509,7 +509,7 @@ def __init__( random_state : int, RandomState object or None, optional (default=None) Random number seed. If int, this number is used to seed the C++ code. - If RandomState object (numpy), a random integer is picked based on its state to seed the C++ code. + If RandomState or Generator object (numpy), a random integer is picked based on its state to seed the C++ code. If None, default seeds in C++ code are used. n_jobs : int or None, optional (default=None) Number of parallel threads to use for training (can be changed at prediction time by @@ -710,6 +710,10 @@ def _process_params(self, stage: str) -> Dict[str, Any]: if isinstance(params['random_state'], np.random.RandomState): params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max) + elif isinstance(params['random_state'], np_random_Generator): + params['random_state'] = int( + params['random_state'].integers(np.iinfo(np.int32).max) + ) if self._n_classes > 2: for alias in _ConfigAliases.get('num_class'): params.pop(alias, None) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 2247c9a512d2..06b9ef18f9af 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -534,11 +534,12 @@ def test_non_serializable_objects_in_callbacks(tmp_path): assert gbm.booster_.attr_set_inside_callback == 40 -def test_random_state_object(): +@pytest.mark.parametrize("rng_constructor", [np.random.RandomState, np.random.default_rng]) +def test_random_state_object(rng_constructor): X, y = load_iris(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) - state1 = np.random.RandomState(123) - state2 = np.random.RandomState(123) + state1 = rng_constructor(123) + state2 = rng_constructor(123) clf1 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state1) clf2 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state2) # Test if random_state is properly stored