From 516bde95015b05e57ff41b19d9bec19b0c48d7e6 Mon Sep 17 00:00:00 2001
From: Oliver Borchert
Date: Wed, 22 Nov 2023 22:50:31 +0100
Subject: [PATCH] [python-package] Allow to pass Arrow array as groups (#6166)

---
 include/LightGBM/c_api.h                |  3 +-
 include/LightGBM/dataset.h              |  4 ++
 python-package/lightgbm/basic.py        | 15 +++--
 src/io/dataset.cpp                      |  2 +
 src/io/metadata.cpp                     | 28 ++++++---
 tests/python_package_test/test_arrow.py | 77 +++++++++++++++++--------
 6 files changed, 89 insertions(+), 40 deletions(-)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index fd337cbc7cbe..eafe6fab7825 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -558,9 +558,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
 /*!
  * \brief Set vector to a content in info.
  * \note
+ * - \a group converts input datatype into ``int32``;
  * - \a label and \a weight convert input datatype into ``float32``.
  * \param handle Handle of dataset
- * \param field_name Field name, can be \a label, \a weight
+ * \param field_name Field name, can be \a label, \a weight, \a group
  * \param n_chunks The number of Arrow arrays passed to this function
  * \param chunks Pointer to the list of Arrow arrays
  * \param schema Pointer to the schema of all Arrow arrays
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 48c1bee804d7..bf8264276a5f 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -116,6 +116,7 @@ class Metadata {
   void SetWeights(const ArrowChunkedArray& array);
 
   void SetQuery(const data_size_t* query, data_size_t len);
+  void SetQuery(const ArrowChunkedArray& array);
 
   void SetPosition(const data_size_t* position, data_size_t len);
 
@@ -348,6 +349,9 @@ class Metadata {
   void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
   /*! \brief Insert queries at the given index */
   void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
+  /*! \brief Set queries from pointers to the first element and the end of an iterator. */
+  template <typename It>
+  void SetQueriesFromIterator(It first, It last);
   /*! \brief Filename of current data */
   std::string data_filename_;
   /*! \brief Number of data */
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 008ff1727d78..b55546941f77 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -70,7 +70,9 @@
     List[float],
     List[int],
     np.ndarray,
-    pd_Series
+    pd_Series,
+    pa_Array,
+    pa_ChunkedArray,
 ]
 _LGBM_PositionType = Union[
     np.ndarray,
@@ -1652,7 +1654,7 @@ def __init__(
             If this is Dataset for validation, training data should be used as reference.
         weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Weight for each instance. Weights should be non-negative.
-        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Group/query data.
             Only used in the learning-to-rank task.
             sum(group) = n_samples.
@@ -2432,7 +2434,7 @@ def create_valid(
             Label of the data.
         weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Weight for each instance. Weights should be non-negative.
-        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Group/query data.
             Only used in the learning-to-rank task.
             sum(group) = n_samples.
@@ -2889,7 +2891,7 @@ def set_group(
 
         Parameters
         ----------
-        group : list, numpy 1-D array, pandas Series or None
+        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
             Group/query data.
             Only used in the learning-to-rank task.
             sum(group) = n_samples.
@@ -2903,7 +2905,8 @@
         """
         self.group = group
         if self._handle is not None and group is not None:
-            group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
+            if not _is_pyarrow_array(group):
+                group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
             self.set_field('group', group)
             # original values can be modified at cpp side
             constructed_group = self.get_field('group')
@@ -4431,7 +4434,7 @@ def refit(
 
             .. versionadded:: 4.0.0
 
-        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Group/query size for ``data``.
             Only used in the learning-to-rank task.
             sum(group) = n_samples.
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index 01eb41b71367..78dd5e4319a5 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -904,6 +904,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray &ca) {
     metadata_.SetLabel(ca);
   } else if (name == std::string("weight") || name == std::string("weights")) {
     metadata_.SetWeights(ca);
+  } else if (name == std::string("query") || name == std::string("group")) {
+    metadata_.SetQuery(ca);
   } else {
     return false;
   }
diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp
index ed4fb135e62a..d94b0ed3f2f7 100644
--- a/src/io/metadata.cpp
+++ b/src/io/metadata.cpp
@@ -507,30 +507,34 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len) {
   // CUDA is handled after all insertions are complete
 }
 
-void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
+template <typename It>
+void Metadata::SetQueriesFromIterator(It first, It last) {
   std::lock_guard<std::mutex> lock(mutex_);
-  // save to nullptr
-  if (query == nullptr || len == 0) {
+  // Clear query boundaries on empty input
+  if (last - first == 0) {
     query_boundaries_.clear();
     num_queries_ = 0;
     return;
   }
+
   data_size_t sum = 0;
 #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum)
-  for (data_size_t i = 0; i < len; ++i) {
-    sum += query[i];
+  for (data_size_t i = 0; i < last - first; ++i) {
+    sum += first[i];
   }
   if (num_data_ != sum) {
-    Log::Fatal("Sum of query counts is not same with #data");
+    Log::Fatal("Sum of query counts (%i) differs from the length of #data (%i)", sum, num_data_);
   }
-  num_queries_ = len;
+  num_queries_ = last - first;
+
   query_boundaries_.resize(num_queries_ + 1);
   query_boundaries_[0] = 0;
   for (data_size_t i = 0; i < num_queries_; ++i) {
-    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
+    query_boundaries_[i + 1] = query_boundaries_[i] + first[i];
   }
   CalculateQueryWeights();
   query_load_from_file_ = false;
+
 #ifdef USE_CUDA
   if (cuda_metadata_ != nullptr) {
     if (query_weights_.size() > 0) {
@@ -543,6 +547,14 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
 #endif  // USE_CUDA
 }
 
+void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
+  SetQueriesFromIterator(query, query + len);
+}
+
+void Metadata::SetQuery(const ArrowChunkedArray& array) {
+  SetQueriesFromIterator(array.begin(), array.end());
+}
+
 void Metadata::SetPosition(const data_size_t* positions, data_size_t len) {
   std::lock_guard<std::mutex> lock(mutex_);
   // save to nullptr
diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index 40482a904a62..38b053e94fd5 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -1,7 +1,6 @@
 # coding: utf-8
 import filecmp
-from pathlib import Path
-from typing import Any, Callable, Dict
+from typing import Any, Dict
 
 import numpy as np
 import pyarrow as pa
@@ -15,6 +14,21 @@
 #                                            UTILITIES                                             #
 # ----------------------------------------------------------------------------------------------- #
 
+_INTEGER_TYPES = [
+    pa.int8(),
+    pa.int16(),
+    pa.int32(),
+    pa.int64(),
+    pa.uint8(),
+    pa.uint16(),
+    pa.uint32(),
+    pa.uint64(),
+]
+_FLOAT_TYPES = [
+    pa.float32(),
+    pa.float64(),
+]
+
 
 def generate_simple_arrow_table() -> pa.Table:
     columns = [
@@ -85,9 +99,7 @@ def dummy_dataset_params() -> Dict[str, Any]:
         (lambda: generate_random_arrow_table(100, 10000, 43), {}),
     ],
 )
-def test_dataset_construct_fuzzy(
-    tmp_path: Path, arrow_table_fn: Callable[[], pa.Table], dataset_params: Dict[str, Any]
-):
+def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
     arrow_table = arrow_table_fn()
 
     arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
@@ -108,17 +120,23 @@ def test_dataset_construct_fields_fuzzy():
     arrow_table = generate_random_arrow_table(3, 1000, 42)
     arrow_labels = generate_random_arrow_array(1000, 42)
     arrow_weights = generate_random_arrow_array(1000, 42)
+    arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())
 
-    arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights)
+    arrow_dataset = lgb.Dataset(
+        arrow_table, label=arrow_labels, weight=arrow_weights, group=arrow_groups
+    )
     arrow_dataset.construct()
 
     pandas_dataset = lgb.Dataset(
-        arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy()
+        arrow_table.to_pandas(),
+        label=arrow_labels.to_numpy(),
+        weight=arrow_weights.to_numpy(),
+        group=arrow_groups.to_numpy(),
     )
     pandas_dataset.construct()
 
     # Check for equality
-    for field in ("label", "weight"):
+    for field in ("label", "weight", "group"):
         np_assert_array_equal(
             arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
         )
@@ -133,22 +151,8 @@
     ["array_type", "label_data"],
     [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
 )
-@pytest.mark.parametrize(
-    "arrow_type",
-    [
-        pa.int8(),
-        pa.int16(),
-        pa.int32(),
-        pa.int64(),
-        pa.uint8(),
-        pa.uint16(),
-        pa.uint32(),
-        pa.uint64(),
-        pa.float32(),
-        pa.float64(),
-    ],
-)
-def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
+@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES)
+def test_dataset_construct_labels(array_type, label_data, arrow_type):
     data = generate_dummy_arrow_table()
     labels = array_type(label_data, type=arrow_type)
     dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
@@ -175,7 +179,7 @@ def test_dataset_construct_weights_none():
     [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
 )
 @pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
-def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any):
Any): +def test_dataset_construct_weights(array_type, weight_data, arrow_type): data = generate_dummy_arrow_table() weights = array_type(weight_data, type=arrow_type) dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params()) @@ -183,3 +187,26 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32) np_assert_array_equal(expected, dataset.get_weight(), strict=True) + + +# -------------------------------------------- GROUPS ------------------------------------------- # + + +@pytest.mark.parametrize( + ["array_type", "group_data"], + [ + (pa.array, [2, 3]), + (pa.chunked_array, [[2], [3]]), + (pa.chunked_array, [[], [2, 3]]), + (pa.chunked_array, [[2], [], [3], []]), + ], +) +@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES) +def test_dataset_construct_groups(array_type, group_data, arrow_type): + data = generate_dummy_arrow_table() + groups = array_type(group_data, type=arrow_type) + dataset = lgb.Dataset(data, group=groups, params=dummy_dataset_params()) + dataset.construct() + + expected = np.array([0, 2, 5], dtype=np.int32) + np_assert_array_equal(expected, dataset.get_field("group"), strict=True)