From 0af7a7cffa74591ddb0f3b7a6c4447ae160fd0b9 Mon Sep 17 00:00:00 2001
From: Oliver Borchert
Date: Tue, 31 Oct 2023 00:03:08 +0100
Subject: [PATCH] [python-package] Allow to pass Arrow table as training data

---
 .ci/test-python-oldest.sh               |   4 +-
 .ci/test.sh                             |   4 +
 .ci/test_windows.ps1                    |   2 +
 include/LightGBM/arrow.h                | 256 ++++++++++++++++++++++++
 include/LightGBM/arrow.tpp              | 190 ++++++++++++++++++
 include/LightGBM/c_api.h                |  18 ++
 include/LightGBM/dataset.h              |  32 +--
 python-package/lightgbm/basic.py        |  92 ++++++++-
 python-package/lightgbm/compat.py       |  30 +++
 python-package/pyproject.toml           |   4 +
 src/c_api.cpp                           |  94 +++++++++
 tests/cpp_tests/test_arrow.cpp          | 210 +++++++++++++++++++
 tests/python_package_test/test_arrow.py |  99 +++++++++
 13 files changed, 1017 insertions(+), 18 deletions(-)
 create mode 100644 include/LightGBM/arrow.h
 create mode 100644 include/LightGBM/arrow.tpp
 create mode 100644 tests/cpp_tests/test_arrow.cpp
 create mode 100644 tests/python_package_test/test_arrow.py

diff --git a/.ci/test-python-oldest.sh b/.ci/test-python-oldest.sh
index 3a0ea08dddda..40dfd393f1fe 100644
--- a/.ci/test-python-oldest.sh
+++ b/.ci/test-python-oldest.sh
@@ -7,9 +7,11 @@
 #
 echo "installing lightgbm's dependencies"
 pip install \
+  'cffi==1.15.1' \
   'dataclasses' \
-  'numpy==1.12.0' \
+  'numpy==1.16.6' \
   'pandas==0.24.0' \
+  'pyarrow==6.0.1' \
   'scikit-learn==0.18.2' \
   'scipy==0.19.0' \
 || exit -1
diff --git a/.ci/test.sh b/.ci/test.sh
index 37fbd19152d6..9ffd48efe452 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -130,11 +130,13 @@ fi
 # including python=version[build=*cpython] to ensure that conda doesn't fall back to pypy
 mamba create -q -y -n $CONDA_ENV \
     ${CONSTRAINED_DEPENDENCIES} \
+    cffi \
     cloudpickle \
     joblib \
     matplotlib \
     numpy \
     psutil \
+    pyarrow \
     pytest \
     ${CONDA_PYTHON_REQUIREMENT} \
     python-graphviz \
@@ -315,11 +317,13 @@ matplotlib.use\(\"Agg\"\)\
 
 # importing the library should succeed even if all optional dependencies are not present
 conda uninstall --force --yes \
+    cffi \
     dask \
     distributed \
     joblib \
     matplotlib \
     psutil \
+    pyarrow \
     python-graphviz \
     scikit-learn || exit -1
 python -c "import lightgbm" || exit -1
diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1
index 413af821e065..6b02aed6ce8b 100644
--- a/.ci/test_windows.ps1
+++ b/.ci/test_windows.ps1
@@ -52,12 +52,14 @@ conda install brotlipy
 
 conda update -q -y conda
 conda create -q -y -n $env:CONDA_ENV `
+    cffi `
     cloudpickle `
     joblib `
     matplotlib `
    numpy `
     pandas `
     psutil `
+    pyarrow `
     pytest `
     "python=$env:PYTHON_VERSION[build=*cpython]" `
     python-graphviz `
diff --git a/include/LightGBM/arrow.h b/include/LightGBM/arrow.h
new file mode 100644
index 000000000000..3d1c74713bd3
--- /dev/null
+++ b/include/LightGBM/arrow.h
@@ -0,0 +1,256 @@
+/*!
+ * Copyright (c) 2023 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ *
+ * Author: Oliver Borchert
+ */
+
+#ifndef LIGHTGBM_ARROW_H_
+#define LIGHTGBM_ARROW_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <stdexcept>
+#include <vector>
+
+/* -------------------------------------- C DATA INTERFACE ------------------------------------- */
+// The C data interface is taken from
+// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions
+// and is available under Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
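+//
+// Note: these struct definitions must match the Arrow C data interface ABI exactly,
+// so that ArrowArray/ArrowSchema pointers produced by any Arrow implementation
+// (e.g. pyarrow through its cffi bridge) can be consumed here without copying.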
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define ARROW_FLAG_DICTIONARY_ORDERED 1
+#define ARROW_FLAG_NULLABLE 2
+#define ARROW_FLAG_MAP_KEYS_SORTED 4
+
+struct ArrowSchema {
+  // Array type description
+  const char* format;
+  const char* name;
+  const char* metadata;
+  int64_t flags;
+  int64_t n_children;
+  struct ArrowSchema** children;
+  struct ArrowSchema* dictionary;
+
+  // Release callback
+  void (*release)(struct ArrowSchema*);
+  // Opaque producer-specific data
+  void* private_data;
+};
+
+struct ArrowArray {
+  // Array data description
+  int64_t length;
+  int64_t null_count;
+  int64_t offset;
+  int64_t n_buffers;
+  int64_t n_children;
+  const void** buffers;
+  struct ArrowArray** children;
+  struct ArrowArray* dictionary;
+
+  // Release callback
+  void (*release)(struct ArrowArray*);
+  // Opaque producer-specific data
+  void* private_data;
+};
+
+#ifdef __cplusplus
+}
+#endif
+
+/* --------------------------------------------------------------------------------------------- */
+/*                                         CHUNKED ARRAY                                          */
+/* --------------------------------------------------------------------------------------------- */
+
+namespace LightGBM {
+
+/**
+ * @brief Arrow array-like container for a list of Arrow arrays.
+ */
+class ArrowChunkedArray {
+  /* List of length `n` for `n` chunks containing the individual Arrow arrays. */
+  std::vector<const ArrowArray*> chunks_;
+  /* Schema for all chunks. */
+  const ArrowSchema* schema_;
+  /* List of length `n + 1` for `n` chunks containing the offsets for each chunk. */
+  std::vector<int64_t> chunk_offsets_;
+
+  inline void construct_chunk_offsets() {
+    chunk_offsets_.reserve(chunks_.size() + 1);
+    chunk_offsets_.emplace_back(0);
+    for (size_t k = 0; k < chunks_.size(); ++k) {
+      chunk_offsets_.emplace_back(chunks_[k]->length + chunk_offsets_.back());
+    }
+  }
+
+ public:
+  /**
+   * @brief Construct a new Arrow Chunked Array object.
+   *
+   * @param chunks A list with the chunks.
+   * @param schema The schema for all chunks.
+   */
+  inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema) {
+    chunks_ = chunks;
+    schema_ = schema;
+    construct_chunk_offsets();
+  }
+
+  /**
+   * @brief Construct a new Arrow Chunked Array object.
+   *
+   * @param n_chunks The number of chunks.
+   * @param chunks A C-style array containing the chunks.
+   * @param schema The schema for all chunks.
+   */
+  inline ArrowChunkedArray(int64_t n_chunks,
+                           const struct ArrowArray* chunks,
+                           const struct ArrowSchema* schema) {
+    chunks_.reserve(n_chunks);
+    for (auto k = 0; k < n_chunks; ++k) {
+      chunks_.push_back(&chunks[k]);
+    }
+    schema_ = schema;
+    construct_chunk_offsets();
+  }
+
+  /**
+   * @brief Get the length of the chunked array.
+   *        This method returns the cumulative length of all chunks.
+   *        Complexity: O(1)
+   *
+   * @return int64_t The number of elements in the chunked array.
+   */
+  inline int64_t get_length() const { return chunk_offsets_.back(); }
+
+  /* ----------------------------------------- ITERATOR ---------------------------------------- */
+  template <typename T>
+  class Iterator {
+    using getter_fn = std::function<T(const ArrowArray*, int64_t)>;
+
+    /* Reference to the chunked array that this iterator iterates over. */
+    const ArrowChunkedArray& array_;
+    /* Function to fetch the value at a certain index from a single chunk. */
+    getter_fn get_;
+    /* The chunk the iterator currently points to. */
+    int64_t ptr_chunk_;
+    /* The index inside the current chunk that the iterator points to. */
+    int64_t ptr_offset_;
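+
+    // Illustrative note: the iterator's position is the pair (ptr_chunk_, ptr_offset_).
+    // For chunks of lengths {2, 4, 1}, the enclosing array's chunk_offsets_ is
+    // {0, 2, 6, 7}; the global index 3 thus corresponds to ptr_chunk_ = 1 and
+    // ptr_offset_ = 1, i.e. the second element of the second chunk.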
+
+   public:
+    using iterator_category = std::random_access_iterator_tag;
+    using difference_type = int64_t;
+    using value_type = T;
+    using pointer = value_type*;
+    using reference = value_type&;
+
+    /**
+     * @brief Construct a new Iterator object.
+     *
+     * @param array Reference to the chunked array to iterate over.
+     * @param get Function to fetch the value at a certain index from a single chunk.
+     * @param ptr_chunk The index of the chunk whose first element the iterator points to.
+     */
+    Iterator(const ArrowChunkedArray& array, getter_fn get, int64_t ptr_chunk);
+
+    T operator*() const;
+    template <typename I>
+    T operator[](I idx) const;
+
+    Iterator<T>& operator++();
+    Iterator<T>& operator--();
+    Iterator<T>& operator+=(int64_t c);
+
+    template <typename V>
+    friend bool operator==(const Iterator<V>& a, const Iterator<V>& b);
+    template <typename V>
+    friend bool operator!=(const Iterator<V>& a, const Iterator<V>& b);
+    template <typename V>
+    friend int64_t operator-(const Iterator<V>& a, const Iterator<V>& b);
+  };
+
+  /**
+   * @brief Obtain an iterator to the beginning of the chunked array.
+   *
+   * @tparam T The value type of the iterator. May be any primitive type.
+   * @return Iterator<T> The iterator.
+   */
+  template <typename T>
+  inline Iterator<T> begin() const;
+
+  /**
+   * @brief Obtain an iterator to the end of the chunked array.
+   *
+   * @tparam T The value type of the iterator. May be any primitive type.
+   * @return Iterator<T> The iterator.
+   */
+  template <typename T>
+  inline Iterator<T> end() const;
+
+  template <typename T>
+  friend int64_t operator-(const Iterator<T>& a, const Iterator<T>& b);
+};
+
+/**
+ * @brief Arrow container for a list of chunked arrays.
+ */
+class ArrowTable {
+  std::vector<ArrowChunkedArray> columns_;
+
+ public:
+  /**
+   * @brief Construct a new Arrow Table object.
+   *
+   * @param n_chunks The number of chunks.
+   * @param chunks A C-style array containing the chunks.
+   * @param schema The schema for all chunks.
+   */
+  inline ArrowTable(int64_t n_chunks, const ArrowArray* chunks, const ArrowSchema* schema) {
+    columns_.reserve(schema->n_children);
+    for (int64_t j = 0; j < schema->n_children; ++j) {
+      std::vector<const ArrowArray*> children_chunks;
+      children_chunks.reserve(n_chunks);
+      for (int64_t k = 0; k < n_chunks; ++k) {
+        children_chunks.push_back(chunks[k].children[j]);
+      }
+      columns_.emplace_back(children_chunks, schema->children[j]);
+    }
+  }
+
+  /**
+   * @brief Get the number of rows in the table.
+   *
+   * @return int64_t The number of rows.
+   */
+  inline int64_t get_num_rows() const { return columns_.front().get_length(); }
+
+  /**
+   * @brief Get the number of columns of this table.
+   *
+   * @return int64_t The column count.
+   */
+  inline int64_t get_num_columns() const { return columns_.size(); }
+
+  /**
+   * @brief Get the column at a particular index.
+   *
+   * @param idx The index of the column, must be in the range `[0, num_columns)`.
+   * @return const ArrowChunkedArray& The chunked array for the child at the provided index.
+   */
+  inline const ArrowChunkedArray& get_column(size_t idx) const { return this->columns_[idx]; }
+};
+
+}  // namespace LightGBM
+
+#include "arrow.tpp"
+
+#endif /* LIGHTGBM_ARROW_H_ */
diff --git a/include/LightGBM/arrow.tpp b/include/LightGBM/arrow.tpp
new file mode 100644
index 000000000000..67b481c9497e
--- /dev/null
+++ b/include/LightGBM/arrow.tpp
@@ -0,0 +1,190 @@
+#include <LightGBM/arrow.h>
+
+#ifndef ARROW_TPP_
+#define ARROW_TPP_
+
+namespace LightGBM {
+
+/**
+ * @brief Obtain a function to access an index from an Arrow array.
+ *
+ * @tparam T The return type of the function, must be a primitive type.
+ * @param dtype The Arrow format string describing the datatype of the Arrow array.
+ * @return std::function<T(const ArrowArray*, int64_t)> The index accessor function.
+ */
+template <typename T>
+std::function<T(const ArrowArray*, int64_t)> get_index_accessor(const char* dtype);
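+
+// For illustration: for a column whose Arrow format string is "i" (int32),
+// get_index_accessor<double>("i") yields an accessor that reads the chunk's
+// int32 data buffer and casts each valid element to double (NaN for nulls).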
+
+/* ---------------------------------- ITERATOR INITIALIZATION ---------------------------------- */
+
+template <typename T>
+inline ArrowChunkedArray::Iterator<T> ArrowChunkedArray::begin() const {
+  return ArrowChunkedArray::Iterator<T>(*this, get_index_accessor<T>(schema_->format), 0);
+}
+
+template <typename T>
+inline ArrowChunkedArray::Iterator<T> ArrowChunkedArray::end() const {
+  return ArrowChunkedArray::Iterator<T>(*this, get_index_accessor<T>(schema_->format),
+                                        chunk_offsets_.size() - 1);
+}
+
+/* ---------------------------------- ITERATOR IMPLEMENTATION ---------------------------------- */
+
+template <typename T>
+ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array,
+                                         getter_fn get,
+                                         int64_t ptr_chunk)
+    : array_(array), get_(get), ptr_chunk_(ptr_chunk) {
+  this->ptr_offset_ = 0;
+}
+
+template <typename T>
+T ArrowChunkedArray::Iterator<T>::operator*() const {
+  auto chunk = array_.chunks_[ptr_chunk_];
+  return static_cast<T>(get_(chunk, ptr_offset_));
+}
+
+template <typename T>
+template <typename I>
+T ArrowChunkedArray::Iterator<T>::operator[](I idx) const {
+  auto it = std::lower_bound(array_.chunk_offsets_.begin(), array_.chunk_offsets_.end(), idx,
+                             [](int64_t a, int64_t b) { return a <= b; });
+
+  auto chunk_idx = std::distance(array_.chunk_offsets_.begin() + 1, it);
+  auto chunk = array_.chunks_[chunk_idx];
+
+  auto ptr_offset = static_cast<int64_t>(idx) - array_.chunk_offsets_[chunk_idx];
+  return static_cast<T>(get_(chunk, ptr_offset));
+}
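+
+// Note: operator[] takes an *absolute* index into the chunked array, not an
+// offset relative to the iterator's current position; the binary search over
+// chunk_offsets_ maps the index to its (chunk, in-chunk offset) pair.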
+ */ +template +inline T arrow_primitive_missing_value() { + return 0; +} + +template <> +inline double arrow_primitive_missing_value() { + return std::numeric_limits::quiet_NaN(); +} + +template <> +inline float arrow_primitive_missing_value() { + return std::numeric_limits::quiet_NaN(); +} + +template +struct ArrayIndexAccessor { + V operator()(const ArrowArray* array, size_t idx) { + auto buffer_idx = idx + array->offset; + + // For primitive types, buffer at idx 0 provides validity, buffer at idx 1 data, see: + // https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout + auto validity = static_cast(array->buffers[0]); + + // Take return value from data buffer conditional on the validity of the index: + // - The structure of validity bitmasks is taken from here: + // https://arrow.apache.org/docs/format/Columnar.html#validity-bitmaps + // - If the bitmask is NULL, all indices are valid + if (validity == nullptr || !(validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) { + // In case the index is valid, we take it from the data buffer + auto data = static_cast(array->buffers[1]); + return static_cast(data[buffer_idx]); + } + + // In case the index is not valid, we return a default value + return arrow_primitive_missing_value(); + } +}; + +template +std::function get_index_accessor(const char* dtype) { + // Mapping obtained from: + // https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings + switch (dtype[0]) { + case 'c': + return ArrayIndexAccessor(); + case 'C': + return ArrayIndexAccessor(); + case 's': + return ArrayIndexAccessor(); + case 'S': + return ArrayIndexAccessor(); + case 'i': + return ArrayIndexAccessor(); + case 'I': + return ArrayIndexAccessor(); + case 'l': + return ArrayIndexAccessor(); + case 'L': + return ArrayIndexAccessor(); + case 'f': + return ArrayIndexAccessor(); + case 'g': + return ArrayIndexAccessor(); + default: + throw std::invalid_argument("unsupported Arrow datatype"); + } +} + +} // namespace LightGBM + +#endif diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index bba46a02a492..01f55d5ddfb4 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -13,6 +13,7 @@ #ifndef LIGHTGBM_C_API_H_ #define LIGHTGBM_C_API_H_ +#include #include #ifdef __cplusplus @@ -437,6 +438,23 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMats(int32_t nmat, const DatasetHandle reference, DatasetHandle* out); +/*! + * \brief Create dataset from Arrow. + * \param n_chunks The number of Arrow arrays passed to this function + * \param chunks Pointer to the list of Arrow arrays + * \param schema Pointer to the schema of all Arrow arrays + * \param parameters Additional parameters + * \param reference Used to align bin mapper with other dataset, nullptr means isn't used + * \param[out] out Created dataset + * \return 0 when succeed, -1 when failure happens + */ +LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromArrow(int64_t n_chunks, + const ArrowArray* chunks, + const ArrowSchema* schema, + const char* parameters, + const DatasetHandle reference, + DatasetHandle *out); + /*! * \brief Create subset of a data. 
 /*!
  * \brief Create subset of a data.
  * \param handle Handle of full dataset
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index e7baa42dc2e6..e94e0d943a72 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -5,6 +5,7 @@
 #ifndef LIGHTGBM_DATASET_H_
 #define LIGHTGBM_DATASET_H_
 
+#include <LightGBM/arrow.h>
 #include <LightGBM/config.h>
 #include <LightGBM/feature_group.h>
 #include <LightGBM/meta.h>
@@ -545,24 +546,29 @@ class Dataset {
     }
   }
 
-  inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
-    if (is_finish_load_) { return; }
-    for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) {
-      int feature_idx = used_feature_map_[i];
-      if (feature_idx >= 0) {
-        const int group = feature2group_[feature_idx];
-        const int sub_feature = feature2subfeature_[feature_idx];
-        feature_groups_[group]->PushData(tid, sub_feature, row_idx, feature_values[i]);
-        if (has_raw_) {
-          int feat_ind = numeric_feature_map_[feature_idx];
-          if (feat_ind >= 0) {
-            raw_data_[feat_ind][row_idx] = static_cast<float>(feature_values[i]);
-          }
+  inline void PushOneValue(int tid, data_size_t row_idx, size_t col_idx, double value) {
+    if (this->is_finish_load_)
+      return;
+    auto feature_idx = this->used_feature_map_[col_idx];
+    if (feature_idx >= 0) {
+      auto group = this->feature2group_[feature_idx];
+      auto sub_feature = this->feature2subfeature_[feature_idx];
+      this->feature_groups_[group]->PushData(tid, sub_feature, row_idx, value);
+      if (this->has_raw_) {
+        auto feat_ind = numeric_feature_map_[feature_idx];
+        if (feat_ind >= 0) {
+          raw_data_[feat_ind][row_idx] = static_cast<float>(value);
         }
       }
     }
   }
 
+  inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
+    for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) {
+      this->PushOneValue(tid, row_idx, i, feature_values[i]);
+    }
+  }
+
   inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<std::pair<int, double>>& feature_values) {
     if (is_finish_load_) { return; }
     std::vector<bool> is_feature_added(num_features_, false);
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 9b833afada84..d0d9f0b136f8 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -18,7 +18,8 @@
 import numpy as np
 import scipy.sparse
 
-from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series
+from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
+                     dt_DataTable, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
 from .libpath import find_lib_path
 
 if TYPE_CHECKING:
@@ -90,7 +91,8 @@
     scipy.sparse.spmatrix,
     "Sequence",
     List["Sequence"],
-    List[np.ndarray]
+    List[np.ndarray],
+    pa_Table
 ]
 _LGBM_LabelType = Union[
     List[float],
@@ -351,6 +353,59 @@
     )
 
 
+def _is_pyarrow_table(data: Any) -> bool:
+    """Check whether data is a PyArrow table."""
+    return isinstance(data, pa_Table)
+
+
+class _ArrowCArray:
+    """Simple wrapper around the C representation of an Arrow type."""
+
+    n_chunks: int
+    chunks: arrow_cffi.CData
+    schema: arrow_cffi.CData
+
+    def __init__(self, n_chunks: int, chunks: arrow_cffi.CData, schema: arrow_cffi.CData):
+        self.n_chunks = n_chunks
+        self.chunks = chunks
+        self.schema = schema
+
+    @property
+    def chunks_ptr(self) -> int:
+        """Returns the address of the pointer to the list of chunks making up the array."""
+        return int(arrow_cffi.cast("uintptr_t", arrow_cffi.addressof(self.chunks[0])))
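+
+    # Note: the raw addresses exposed below stay valid only for as long as this
+    # `_ArrowCArray` (which owns the cffi-allocated structs) is kept alive by
+    # the caller.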
+
+    @property
+    def schema_ptr(self) -> int:
+        """Returns the address of the pointer to the schema of the array."""
+        return int(arrow_cffi.cast("uintptr_t", self.schema))
+
+
+def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray:
+    """Export an Arrow type to its C representation."""
+    # Obtain objects to export
+    if isinstance(data, pa_Table):
+        export_objects = data.to_batches()
+    else:
+        raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow")
+
+    # Prepare export
+    chunks = arrow_cffi.new("struct ArrowArray[]", len(export_objects))
+    schema = arrow_cffi.new("struct ArrowSchema*")
+
+    # Export all objects
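+    # (The schema is exported only once, together with the first record batch:
+    # all record batches of a pyarrow table share a single schema.)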
+    for i, obj in enumerate(export_objects):
+        chunk_ptr = int(arrow_cffi.cast("uintptr_t", arrow_cffi.addressof(chunks[i])))
+        if i == 0:
+            schema_ptr = int(arrow_cffi.cast("uintptr_t", schema))
+            obj._export_to_c(chunk_ptr, schema_ptr)
+        else:
+            obj._export_to_c(chunk_ptr)
+
+    return _ArrowCArray(len(chunks), chunks, schema)
+
+
 def _data_to_2d_numpy(
     data: Any,
     dtype: "np.typing.DTypeLike",
@@ -1562,7 +1617,7 @@ def __init__(
         Parameters
         ----------
-        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
+        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table
             Data source of Dataset.
             If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
         label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
@@ -1581,7 +1636,7 @@
             Init score for Dataset.
         feature_name : list of str, or 'auto', optional (default="auto")
             Feature names.
-            If 'auto' and data is pandas DataFrame, data columns names are used.
+            If 'auto' and data is pandas DataFrame or pyarrow Table, data columns names are used.
         categorical_feature : list of str or int, or 'auto', optional (default="auto")
             Categorical features.
             If list of int, interpreted as indices.
@@ -1938,6 +1993,9 @@ def _lazy_init(
             self.__init_from_csc(data, params_str, ref_dataset)
         elif isinstance(data, np.ndarray):
             self.__init_from_np2d(data, params_str, ref_dataset)
+        elif _is_pyarrow_table(data):
+            self.__init_from_pyarrow_table(data, params_str, ref_dataset)
+            feature_name = data.column_names
         elif isinstance(data, list) and len(data) > 0:
             if _is_list_of_numpy_arrays(data):
                 self.__init_from_list_np2d(data, params_str, ref_dataset)
@@ -2198,6 +2256,32 @@
                 ctypes.byref(self._handle)))
         return self
 
+    def __init_from_pyarrow_table(
+        self,
+        table: pa_Table,
+        params_str: str,
+        ref_dataset: Optional[_DatasetHandle]
+    ) -> "Dataset":
+        """Initialize data from a PyArrow table."""
+        if not PYARROW_INSTALLED:
+            raise LightGBMError("Cannot init Dataset from Arrow without `pyarrow` installed.")
+
+        # Check that the input is valid: we only handle numbers (for now)
+        if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
+            raise ValueError("Arrow table may only have integer or floating point datatypes")
+
+        # Export Arrow table to C
+        c_array = _export_arrow_to_c(table)
+        self._handle = ctypes.c_void_p()
+        _safe_call(_LIB.LGBM_DatasetCreateFromArrow(
+            ctypes.c_int64(c_array.n_chunks),
+            ctypes.c_void_p(c_array.chunks_ptr),
+            ctypes.c_void_p(c_array.schema_ptr),
+            _c_str(params_str),
+            ref_dataset,
+            ctypes.byref(self._handle)))
+        return self
+
     @staticmethod
     def _compare_params_for_warning(
         params: Dict[str, Any],
diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py
index 0a55ccd1e421..7be375e02e85 100644
--- a/python-package/lightgbm/compat.py
+++ b/python-package/lightgbm/compat.py
@@ -185,6 +185,36 @@ class dask_Series:  # type: ignore
         def __init__(self, *args, **kwargs):
             pass
 
+"""pyarrow"""
+try:
+    from pyarrow import Table as pa_Table
+    from pyarrow.cffi import ffi as arrow_cffi
+    from pyarrow.types import is_floating as arrow_is_floating
+    from pyarrow.types import is_integer as arrow_is_integer
+    PYARROW_INSTALLED = True
+except ImportError:
+    PYARROW_INSTALLED = False
+
+    class pa_Table:  # type: ignore
+        """Dummy class for pa.Table."""
+
+        def __init__(self, *args, **kwargs):
+            pass
+
+    class arrow_cffi:  # type: ignore
+        """Dummy class for pyarrow.cffi.ffi."""
+
+        CData = None
+        addressof = None
+        cast = None
+        new = None
+
+        def __init__(self, *args, **kwargs):
+            pass
+
+    arrow_is_integer = None
+    arrow_is_floating = None
+
 """cpu_count()"""
 try:
     from joblib import cpu_count
diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml
index 6e43dc242d1b..83520c5248cd 100644
--- a/python-package/pyproject.toml
+++ b/python-package/pyproject.toml
@@ -33,6 +33,10 @@ requires-python = ">=3.6"
 version = "4.1.0.99"
 
 [project.optional-dependencies]
+arrow = [
+    "cffi>=1.15.1",
+    "pyarrow>=6.0.1"
+]
 dask = [
     "dask[array,dataframe,distributed]>=2.0.0",
     "pandas>=0.24.0"
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 5c98d7d24c01..6467bb54a8fe 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -4,6 +4,7 @@
  */
 #include <LightGBM/c_api.h>
 
+#include <LightGBM/arrow.h>
 #include <LightGBM/boosting.h>
 #include <LightGBM/config.h>
 #include <LightGBM/dataset.h>
@@ -832,6 +833,7 @@ class Booster {
 
 // explicitly declare symbols from LightGBM namespace
 using LightGBM::AllgatherFunction;
+using LightGBM::ArrowTable;
 using LightGBM::Booster;
 using LightGBM::Common::CheckElementsIntervalClosed;
 using LightGBM::Common::RemoveQuotationSymbol;
@@ -1567,6 +1569,98 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
   API_END();
 }
 
+int LGBM_DatasetCreateFromArrow(int64_t n_chunks,
+                                const ArrowArray* chunks,
+                                const ArrowSchema* schema,
+                                const char* parameters,
+                                const DatasetHandle reference,
+                                DatasetHandle* out) {
+  API_BEGIN();
+
+  auto param = Config::Str2Map(parameters);
+  Config config;
+  config.Set(param);
+  OMP_SET_NUM_THREADS(config.num_threads);
+
+  std::unique_ptr<Dataset> ret;
+
+  // Prepare the Arrow data
+  ArrowTable table(n_chunks, chunks, schema);
+
+  // Initialize the dataset
+  if (reference == nullptr) {
+    // If there is no reference dataset, we first sample indices
+    auto sample_indices = CreateSampleIndices(static_cast<int32_t>(table.get_num_rows()), config);
+    auto sample_count = static_cast<int32_t>(sample_indices.size());
+    std::vector<std::vector<double>> sample_values(table.get_num_columns());
+    std::vector<std::vector<int>> sample_idx(table.get_num_columns());
+
+    // Then, we obtain sample values by parallelizing across columns
+    OMP_INIT_EX();
+    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
+    for (int64_t j = 0; j < table.get_num_columns(); ++j) {
+      OMP_LOOP_EX_BEGIN();
+
+      // Values need to be copied from the record batches.
+      sample_values[j].reserve(sample_indices.size());
+      sample_idx[j].reserve(sample_indices.size());
+
+      // The chunks are iterated over in the inner loop as columns can be treated independently.
+      int last_idx = 0;
+      int i = 0;
+      auto it = table.get_column(j).begin<double>();
+      for (auto idx : sample_indices) {
+        std::advance(it, idx - last_idx);
+        auto v = *it;
+        if (std::fabs(v) > kZeroThreshold || std::isnan(v)) {
+          sample_values[j].emplace_back(v);
+          sample_idx[j].emplace_back(i);
+        }
+        last_idx = idx;
+        i++;
+      }
+      OMP_LOOP_EX_END();
+    }
+    OMP_THROW_EX();
+
+    // Finally, we initialize a loader from the sampled values
+    DatasetLoader loader(config, nullptr, 1, nullptr);
+    ret.reset(loader.ConstructFromSampleData(Vector2Ptr<double>(&sample_values).data(),
+                                             Vector2Ptr<int>(&sample_idx).data(),
+                                             table.get_num_columns(),
+                                             VectorSize<double>(sample_values).data(),
+                                             sample_count,
+                                             table.get_num_rows(),
+                                             table.get_num_rows()));
+  } else {
+    ret.reset(new Dataset(static_cast<data_size_t>(table.get_num_rows())));
+    ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
+    if (ret->has_raw()) {
+      ret->ResizeRaw(static_cast<data_size_t>(table.get_num_rows()));
+    }
+  }
+
+  // After sampling and properly initializing all bins, we can add our data to the dataset. Here,
+  // we parallelize across columns.
+  OMP_INIT_EX();
+  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
+  for (int64_t j = 0; j < table.get_num_columns(); ++j) {
+    OMP_LOOP_EX_BEGIN();
+    const int tid = omp_get_thread_num();
+    data_size_t idx = 0;
+    auto column = table.get_column(j);
+    for (auto it = column.begin<double>(), end = column.end<double>(); it != end; ++it) {
+      ret->PushOneValue(tid, idx++, j, *it);
+    }
+    OMP_LOOP_EX_END();
+  }
+  OMP_THROW_EX();
+
+  ret->FinishLoad();
+  *out = ret.release();
+  API_END();
+}
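+
+// Implementation note: both loops above parallelize across columns rather than
+// rows; a single thread walks all chunks of one column sequentially, so each
+// chunk's Arrow buffers are only ever read from one thread at a time.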
+
 int LGBM_DatasetGetSubset(
     const DatasetHandle handle,
     const int32_t* used_row_indices,
diff --git a/tests/cpp_tests/test_arrow.cpp b/tests/cpp_tests/test_arrow.cpp
new file mode 100644
index 000000000000..7e3c57c401f4
--- /dev/null
+++ b/tests/cpp_tests/test_arrow.cpp
@@ -0,0 +1,210 @@
+/*!
+ * Copyright (c) 2023 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ *
+ * Author: Oliver Borchert
+ */
+
+#include <gtest/gtest.h>
+#include <LightGBM/arrow.h>
+
+#include <cmath>
+#include <cstdlib>
+
+using LightGBM::ArrowChunkedArray;
+using LightGBM::ArrowTable;
+
+class ArrowChunkedArrayTest : public testing::Test {
+ protected:
+  void SetUp() override {}
+
+  ArrowArray create_nested_array(const std::vector<ArrowArray*>& arrays) {
+    ArrowArray arr;
+    arr.buffers = nullptr;
+    arr.children = (ArrowArray**)arrays.data();  // NOLINT
+    arr.dictionary = nullptr;
+    arr.length = arrays[0]->length;
+    arr.n_buffers = 0;
+    arr.n_children = arrays.size();
+    arr.null_count = 0;
+    arr.offset = 0;
+    arr.private_data = nullptr;
+    arr.release = nullptr;
+    return arr;
+  }
+
+  template <typename T>
+  ArrowArray create_primitive_array(const std::vector<T>& values,
+                                    int64_t offset = 0,
+                                    std::vector<int64_t> null_indices = {}) {
+    // NOTE: Arrow arrays have 64-bit alignment but we can safely ignore this in tests
+    // 1) Create validity bitmap: a set bit marks a valid (non-null) entry
+    char* validity = nullptr;
+    if (!null_indices.empty()) {
+      validity = static_cast<char*>(calloc(values.size() + sizeof(char) - 1, sizeof(char)));
+      for (size_t i = 0; i < values.size(); ++i) {
+        if (std::find(null_indices.begin(), null_indices.end(), i) == null_indices.end()) {
+          validity[i / 8] |= (1 << (i % 8));
+        }
+      }
+    }
+
+    // 2) Create buffers
+    const void** buffers = (const void**)malloc(sizeof(void*) * 2);  // NOLINT
+    buffers[0] = validity;
+    buffers[1] = values.data() + offset;
+
+    // Create arrow array
+    ArrowArray arr;
+    arr.buffers = buffers;
+    arr.children = nullptr;
+    arr.dictionary = nullptr;
+    arr.length = values.size() - offset;
+    arr.null_count = 0;
+    arr.offset = 0;
+    arr.private_data = nullptr;
+    arr.release = [](ArrowArray* arr) {
+      if (arr->buffers[0] != nullptr)
+        free((void*)(arr->buffers[0]));  // NOLINT
+      free((void*)(arr->buffers));  // NOLINT
+    };
+    return arr;
+  }
+
+  ArrowSchema create_nested_schema(const std::vector<ArrowSchema*>& arrays) {
+    ArrowSchema schema;
+    schema.format = "+s";
+    schema.name = nullptr;
+    schema.metadata = nullptr;
+    schema.flags = 0;
+    schema.n_children = arrays.size();
+    schema.children = (ArrowSchema**)arrays.data();  // NOLINT
+    schema.dictionary = nullptr;
+    schema.private_data = nullptr;
+    schema.release = nullptr;
+    return schema;
+  }
+
+  template <typename T>
+  ArrowSchema create_primitive_schema() {
+    throw std::logic_error("not implemented");
+  }
+
+  template <>
+  ArrowSchema create_primitive_schema<float>() {
+    ArrowSchema schema;
+    schema.format = "f";
+    schema.name = nullptr;
+    schema.metadata = nullptr;
+    schema.flags = 0;
+    schema.n_children = 0;
+    schema.children = nullptr;
+    schema.dictionary = nullptr;
+    schema.private_data = nullptr;
+    schema.release = nullptr;
+    return schema;
+  }
+};
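+
+// The helpers above fabricate minimal Arrow arrays by hand instead of depending
+// on libarrow: buffers[0] holds the validity bitmap (one bit per element, set
+// for valid entries) and buffers[1] the values, per the Arrow columnar format.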
+
+TEST_F(ArrowChunkedArrayTest, GetLength) {
+  std::vector<float> dat1 = {1, 2};
+  auto arr1 = create_primitive_array(dat1);
+
+  ArrowChunkedArray ca1(1, &arr1, nullptr);
+  ASSERT_EQ(ca1.get_length(), 2);
+
+  std::vector<float> dat2 = {3, 4, 5, 6};
+  auto arr2 = create_primitive_array(dat2);
+  ArrowArray arrs[2] = {arr1, arr2};
+  ArrowChunkedArray ca2(2, arrs, nullptr);
+  ASSERT_EQ(ca2.get_length(), 6);
+
+  arr1.release(&arr1);
+  arr2.release(&arr2);
+}
+
+TEST_F(ArrowChunkedArrayTest, GetColumns) {
+  std::vector<float> dat1 = {1, 2, 3};
+  auto arr1 = create_primitive_array(dat1);
+  std::vector<float> dat2 = {4, 5, 6};
+  auto arr2 = create_primitive_array(dat2);
+  std::vector<ArrowArray*> arrs = {&arr1, &arr2};
+  auto arr = create_nested_array(arrs);
+
+  auto schema1 = create_primitive_schema<float>();
+  auto schema2 = create_primitive_schema<float>();
+  std::vector<ArrowSchema*> schemas = {&schema1, &schema2};
+  auto schema = create_nested_schema(schemas);
+
+  ArrowTable table(1, &arr, &schema);
+  ASSERT_EQ(table.get_num_rows(), 3);
+  ASSERT_EQ(table.get_num_columns(), 2);
+
+  auto ca1 = table.get_column(0);
+  ASSERT_EQ(ca1.get_length(), 3);
+  ASSERT_EQ(*ca1.begin<float>(), 1);
+
+  auto ca2 = table.get_column(1);
+  ASSERT_EQ(ca2.get_length(), 3);
+  ASSERT_EQ(*ca2.begin<float>(), 4);
+
+  arr1.release(&arr1);
+  arr2.release(&arr2);
+}
+
+TEST_F(ArrowChunkedArrayTest, IteratorArithmetic) {
+  std::vector<float> dat1 = {1, 2};
+  auto arr1 = create_primitive_array(dat1);
+  std::vector<float> dat2 = {3, 4, 5, 6};
+  auto arr2 = create_primitive_array(dat2);
+  std::vector<float> dat3 = {7};
+  auto arr3 = create_primitive_array(dat3);
+  auto schema = create_primitive_schema<float>();
+
+  ArrowArray arrs[3] = {arr1, arr2, arr3};
+  ArrowChunkedArray ca(3, arrs, &schema);
+
+  // Arithmetic
+  auto it = ca.begin<float>();
+  ASSERT_EQ(*it, 1);
+  ++it;
+  ASSERT_EQ(*it, 2);
+  ++it;
+  ASSERT_EQ(*it, 3);
+  it += 2;
+  ASSERT_EQ(*it, 5);
+  it += 2;
+  ASSERT_EQ(*it, 7);
+  --it;
+  ASSERT_EQ(*it, 6);
+
+  // Subscripts
+  ASSERT_EQ(it[0], 1);
+  ASSERT_EQ(it[1], 2);
+  ASSERT_EQ(it[2], 3);
+  ASSERT_EQ(it[6], 7);
+
+  // End
+  auto end = ca.end<float>();
+  ASSERT_EQ(end - it, 2);
+  ASSERT_EQ(end - ca.begin<float>(), 7);
+
+  arr1.release(&arr1);
+  arr2.release(&arr2);
+  arr3.release(&arr3);
+}
+
+TEST_F(ArrowChunkedArrayTest, OffsetAndValidity) {
+  std::vector<float> dat = {0, 1, 2, 3, 4, 5, 6};
+  auto arr = create_primitive_array(dat, 2, {0, 1});
+  auto schema = create_primitive_schema<float>();
+  ArrowChunkedArray ca(1, &arr, &schema);
+
+  auto it = ca.begin<double>();
+  ASSERT_TRUE(std::isnan(*it));
+  ASSERT_TRUE(std::isnan(*(++it)));
+  ASSERT_EQ(it[2], 4);
+  ASSERT_EQ(it[4], 6);
+
+  arr.release(&arr);
+}
diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
new file mode 100644
index 000000000000..54ca945e1e53
--- /dev/null
+++ b/tests/python_package_test/test_arrow.py
@@ -0,0 +1,99 @@
+# coding: utf-8
+import filecmp
+from pathlib import Path
+from typing import Any, Callable, Dict
+
+import numpy as np
+import pyarrow as pa
+import pytest
+
+import lightgbm as lgb
+
+# ----------------------------------------------------------------------------------------------- #
+#                                            UTILITIES                                             #
+# ----------------------------------------------------------------------------------------------- #
+
+
+def generate_simple_arrow_table() -> pa.Table:
+    columns = [
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint8()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int8()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint16()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int16()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint32()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int32()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint64()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int64()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.float32()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.float64()),
+    ]
+    return pa.Table.from_arrays(columns, names=[f"col_{i}" for i in range(len(columns))])
+
+
+def generate_dummy_arrow_table() -> pa.Table:
+    col1 = pa.chunked_array([[1, 2, 3], [4, 5]], type=pa.uint8())
+    col2 = pa.chunked_array([[0.5, 0.6], [0.1, 0.8, 1.5]], type=pa.float32())
+    return pa.Table.from_arrays([col1, col2], names=["a", "b"])
+
+
+def generate_random_arrow_table(num_columns: int, num_datapoints: int, seed: int) -> pa.Table:
+    columns = [generate_random_arrow_array(num_datapoints, seed + i) for i in range(num_columns)]
+    names = [f"col_{i}" for i in range(num_columns)]
+    return pa.Table.from_arrays(columns, names=names)
+
+
+def generate_random_arrow_array(num_datapoints: int, seed: int) -> pa.ChunkedArray:
+    generator = np.random.default_rng(seed)
+    data = generator.standard_normal(num_datapoints)
+
+    # Set random nulls
+    indices = generator.choice(len(data), size=num_datapoints // 10)
+    data[indices] = None
+
+    # Split data into three random chunks
+    split_points = np.sort(generator.choice(np.arange(1, num_datapoints), 2, replace=False))
+    split_points = np.concatenate([[0], split_points, [num_datapoints]])
+    chunks = [data[split_points[i] : split_points[i + 1]] for i in range(len(split_points) - 1)]
+    chunks = [chunk for chunk in chunks if len(chunk) > 0]
+
+    # Turn chunks into array
+    return pa.chunked_array(chunks, type=pa.float32())
+
+
+def dummy_dataset_params() -> Dict[str, Any]:
+    return {
+        "min_data_in_bin": 1,
+        "min_data_in_leaf": 1,
+    }
+
+
+# ----------------------------------------------------------------------------------------------- #
+#                                            UNIT TESTS                                            #
+# ----------------------------------------------------------------------------------------------- #
+
+# ------------------------------------------- DATASET ------------------------------------------- #
+
+
+@pytest.mark.parametrize(
+    ("arrow_table_fn", "dataset_params"),
+    [  # Use lambda functions here to minimize memory consumption
+        (lambda: generate_simple_arrow_table(), dummy_dataset_params()),
+        (lambda: generate_dummy_arrow_table(), dummy_dataset_params()),
+        (lambda: generate_random_arrow_table(3, 1000, 42), {}),
+        (lambda: generate_random_arrow_table(100, 10000, 43), {}),
+    ],
+)
+def test_dataset_construct_fuzzy(
+    tmp_path: Path, arrow_table_fn: Callable[[], pa.Table], dataset_params: Dict[str, Any]
+):
+    arrow_table = arrow_table_fn()
+
+    arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
+    arrow_dataset.construct()
+
+    pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), params=dataset_params)
+    pandas_dataset.construct()
+
+    arrow_dataset._dump_text(tmp_path / "arrow.txt")
+    pandas_dataset._dump_text(tmp_path / "pandas.txt")
+    assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")
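+
+
+def test_dataset_construct_smoke_with_label():
+    # Illustrative sketch beyond the original test matrix: Arrow supplies the
+    # feature matrix while the label is a plain numpy array, since this patch
+    # adds Arrow support for the training data only. Test name and parameter
+    # choices here are ours, not part of the upstream suite.
+    table = generate_random_arrow_table(3, 1000, 42)
+    label = np.random.default_rng(42).standard_normal(1000)
+    dataset = lgb.Dataset(table, label=label)
+    booster = lgb.train({"num_leaves": 7, "verbosity": -1}, dataset, num_boost_round=2)
+    assert booster.num_trees() == 2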