Skip to content

Commit

Permalink
[python-package] Allow to pass Arrow table as training data (microsof…
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero authored and david-cortes committed Nov 8, 2023
1 parent 88e1d72 commit 906a2bf
Show file tree
Hide file tree
Showing 13 changed files with 1,017 additions and 18 deletions.
4 changes: 3 additions & 1 deletion .ci/test-python-oldest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
#
echo "installing lightgbm's dependencies"
pip install \
'cffi==1.15.1' \
'dataclasses' \
'numpy==1.12.0' \
'numpy==1.16.6' \
'pandas==0.24.0' \
'pyarrow==6.0.1' \
'scikit-learn==0.18.2' \
'scipy==0.19.0' \
|| exit -1
Expand Down
4 changes: 4 additions & 0 deletions .ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ fi
# including python=version[build=*cpython] to ensure that conda doesn't fall back to pypy
mamba create -q -y -n $CONDA_ENV \
${CONSTRAINED_DEPENDENCIES} \
cffi \
cloudpickle \
joblib \
matplotlib \
numpy \
psutil \
pyarrow \
pytest \
${CONDA_PYTHON_REQUIREMENT} \
python-graphviz \
Expand Down Expand Up @@ -315,11 +317,13 @@ matplotlib.use\(\"Agg\"\)\

# importing the library should succeed even if all optional dependencies are not present
conda uninstall --force --yes \
cffi \
dask \
distributed \
joblib \
matplotlib \
psutil \
pyarrow \
python-graphviz \
scikit-learn || exit -1
python -c "import lightgbm" || exit -1
Expand Down
2 changes: 2 additions & 0 deletions .ci/test_windows.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ conda install brotlipy

conda update -q -y conda
conda create -q -y -n $env:CONDA_ENV `
cffi `
cloudpickle `
joblib `
matplotlib `
numpy `
pandas `
psutil `
pyarrow `
pytest `
"python=$env:PYTHON_VERSION[build=*cpython]" `
python-graphviz `
Expand Down
256 changes: 256 additions & 0 deletions include/LightGBM/arrow.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
/*!
* Copyright (c) 2023 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*
* Author: Oliver Borchert
*/

#ifndef LIGHTGBM_ARROW_H_
#define LIGHTGBM_ARROW_H_

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

/* -------------------------------------- C DATA INTERFACE ------------------------------------- */
// The C data interface is taken from
// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions
// and is available under Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0).

#ifdef __cplusplus
extern "C" {
#endif

#define ARROW_FLAG_DICTIONARY_ORDERED 1
#define ARROW_FLAG_NULLABLE 2
#define ARROW_FLAG_MAP_KEYS_SORTED 4

struct ArrowSchema {
// Array type description
const char* format;
const char* name;
const char* metadata;
int64_t flags;
int64_t n_children;
struct ArrowSchema** children;
struct ArrowSchema* dictionary;

// Release callback
void (*release)(struct ArrowSchema*);
// Opaque producer-specific data
void* private_data;
};

struct ArrowArray {
// Array data description
int64_t length;
int64_t null_count;
int64_t offset;
int64_t n_buffers;
int64_t n_children;
const void** buffers;
struct ArrowArray** children;
struct ArrowArray* dictionary;

// Release callback
void (*release)(struct ArrowArray*);
// Opaque producer-specific data
void* private_data;
};

#ifdef __cplusplus
}
#endif

/* --------------------------------------------------------------------------------------------- */
/* CHUNKED ARRAY */
/* --------------------------------------------------------------------------------------------- */

namespace LightGBM {

/**
* @brief Arrow array-like container for a list of Arrow arrays.
*/
class ArrowChunkedArray {
/* List of length `n` for `n` chunks containing the individual Arrow arrays. */
std::vector<const ArrowArray*> chunks_;
/* Schema for all chunks. */
const ArrowSchema* schema_;
/* List of length `n + 1` for `n` chunks containing the offsets for each chunk. */
std::vector<int64_t> chunk_offsets_;

inline void construct_chunk_offsets() {
chunk_offsets_.reserve(chunks_.size() + 1);
chunk_offsets_.emplace_back(0);
for (size_t k = 0; k < chunks_.size(); ++k) {
chunk_offsets_.emplace_back(chunks_[k]->length + chunk_offsets_.back());
}
}

public:
/**
* @brief Construct a new Arrow Chunked Array object.
*
* @param chunks A list with the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema) {
chunks_ = chunks;
schema_ = schema;
construct_chunk_offsets();
}

/**
* @brief Construct a new Arrow Chunked Array object.
*
* @param n_chunks The number of chunks.
* @param chunks A C-style array containing the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowChunkedArray(int64_t n_chunks,
const struct ArrowArray* chunks,
const struct ArrowSchema* schema) {
chunks_.reserve(n_chunks);
for (auto k = 0; k < n_chunks; ++k) {
chunks_.push_back(&chunks[k]);
}
schema_ = schema;
construct_chunk_offsets();
}

/**
* @brief Get the length of the chunked array.
* This method returns the cumulative length of all chunks.
* Complexity: O(1)
*
* @return int64_t The number of elements in the chunked array.
*/
inline int64_t get_length() const { return chunk_offsets_.back(); }

/* ----------------------------------------- ITERATOR ---------------------------------------- */
template <typename T>
class Iterator {
using getter_fn = std::function<T(const ArrowArray*, int64_t)>;

/* Reference to the chunked array that this iterator iterates over. */
const ArrowChunkedArray& array_;
/* Function to fetch the value at a certain index from a single chunk. */
getter_fn get_;
/* The chunk the iterator currently points to. */
int64_t ptr_chunk_;
/* The index inside the current chunk that the iterator points to. */
int64_t ptr_offset_;

public:
using iterator_category = std::random_access_iterator_tag;
using difference_type = int64_t;
using value_type = T;
using pointer = value_type*;
using reference = value_type&;

/**
* @brief Construct a new Iterator object.
*
* @param array Reference to the chunked array to iterator over.
* @param get Function to fetch the value at a certain index from a single chunk.
* @param ptr_chunk The index of the chunk to whose first index the iterator points to.
*/
Iterator(const ArrowChunkedArray& array, getter_fn get, int64_t ptr_chunk);

T operator*() const;
template <typename I>
T operator[](I idx) const;

Iterator<T>& operator++();
Iterator<T>& operator--();
Iterator<T>& operator+=(int64_t c);

template <typename V>
friend bool operator==(const Iterator<V>& a, const Iterator<V>& b);
template <typename V>
friend bool operator!=(const Iterator<V>& a, const Iterator<V>& b);
template <typename V>
friend int64_t operator-(const Iterator<V>& a, const Iterator<V>& b);
};

/**
* @brief Obtain an iterator to the beginning of the chunked array.
*
* @tparam T The value type of the iterator. May be any primitive type.
* @return Iterator<T> The iterator.
*/
template <typename T>
inline Iterator<T> begin() const;

/**
* @brief Obtain an iterator to the beginning of the chunked array.
*
* @tparam T The value type of the iterator. May be any primitive type.
* @return Iterator<T> The iterator.
*/
template <typename T>
inline Iterator<T> end() const;

template <typename V>
friend int64_t operator-(const Iterator<V>& a, const Iterator<V>& b);
};

/**
* @brief Arrow container for a list of chunked arrays.
*/
class ArrowTable {
std::vector<ArrowChunkedArray> columns_;

public:
/**
* @brief Construct a new Arrow Table object.
*
* @param n_chunks The number of chunks.
* @param chunks A C-style array containing the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowTable(int64_t n_chunks, const ArrowArray* chunks, const ArrowSchema* schema) {
columns_.reserve(schema->n_children);
for (int64_t j = 0; j < schema->n_children; ++j) {
std::vector<const ArrowArray*> children_chunks;
children_chunks.reserve(n_chunks);
for (int64_t k = 0; k < n_chunks; ++k) {
children_chunks.push_back(chunks[k].children[j]);
}
columns_.emplace_back(children_chunks, schema->children[j]);
}
}

/**
* @brief Get the number of rows in the table.
*
* @return int64_t The number of rows.
*/
inline int64_t get_num_rows() const { return columns_.front().get_length(); }

/**
* @brief Get the number of columns of this table.
*
* @return int64_t The column count.
*/
inline int64_t get_num_columns() const { return columns_.size(); }

/**
* @brief Get the column at a particular index.
*
* @param idx The index of the column, must me in the range `[0, num_columns)`.
* @return const ArrowChunkedArray& The chunked array for the child at the provided index.
*/
inline const ArrowChunkedArray& get_column(size_t idx) const { return this->columns_[idx]; }
};

} // namespace LightGBM

#include "arrow.tpp"

#endif /* LIGHTGBM_ARROW_H_ */
Loading

0 comments on commit 906a2bf

Please sign in to comment.