Skip to content

Commit

Permalink
[python-package] Allow to pass Arrow table and array as init scores (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero authored Dec 4, 2023
1 parent 5083df1 commit f5b6bd6
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 19 deletions.
5 changes: 3 additions & 2 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
* \brief Set vector to a content in info.
* \note
* - \a group converts input datatype into ``int32``;
* - \a label and \a weight convert input datatype into ``float32``.
* - \a label and \a weight convert input datatype into ``float32``;
* - \a init_score converts input datatype into ``float64``.
* \param handle Handle of dataset
* \param field_name Field name, can be \a label, \a weight, \a group
* \param field_name Field name, can be \a label, \a weight, \a init_score, \a group
* \param n_chunks The number of Arrow arrays passed to this function
* \param chunks Pointer to the list of Arrow arrays
* \param schema Pointer to the schema of all Arrow arrays
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class Metadata {
* \param init_score Initial scores, this class will manage memory for init_score.
*/
void SetInitScore(const double* init_score, data_size_t len);
void SetInitScore(const ArrowChunkedArray& array);


/*!
Expand Down Expand Up @@ -347,6 +348,9 @@ class Metadata {
void SetWeightsFromIterator(It first, It last);
/*! \brief Insert initial scores at the given index */
void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
/*! \brief Set init scores from pointers to the first element and the end of an iterator. */
template <typename It>
void SetInitScoresFromIterator(It first, It last);
/*! \brief Insert queries at the given index */
void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
/*! \brief Set queries from pointers to the first element and the end of an iterator. */
Expand Down
28 changes: 20 additions & 8 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import scipy.sparse

from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
dt_DataTable, pa_Array, pa_ChunkedArray, pa_compute, pa_Table, pd_CategoricalDtype, pd_DataFrame,
pd_Series)
dt_DataTable, pa_Array, pa_chunked_array, pa_ChunkedArray, pa_compute, pa_Table,
pd_CategoricalDtype, pd_DataFrame, pd_Series)
from .libpath import find_lib_path

if TYPE_CHECKING:
Expand Down Expand Up @@ -84,6 +84,9 @@
np.ndarray,
pd_Series,
pd_DataFrame,
pa_Table,
pa_Array,
pa_ChunkedArray,
]
_LGBM_TrainDataType = Union[
str,
Expand Down Expand Up @@ -1660,7 +1663,7 @@ def __init__(
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
Init score for Dataset.
feature_name : list of str, or 'auto', optional (default="auto")
Feature names.
Expand Down Expand Up @@ -2440,7 +2443,7 @@ def create_valid(
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
Init score for Dataset.
params : dict or None, optional (default=None)
Other parameters for validation Dataset.
Expand Down Expand Up @@ -2547,7 +2550,7 @@ def _reverse_update_params(self) -> "Dataset":
def set_field(
self,
field_name: str,
data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Array, pa_ChunkedArray]]
data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Table, pa_Array, pa_ChunkedArray]]
) -> "Dataset":
"""Set property into the Dataset.
Expand Down Expand Up @@ -2576,7 +2579,16 @@ def set_field(
return self

# If the data is a arrow data, we can just pass it to C
if _is_pyarrow_array(data):
if _is_pyarrow_array(data) or _is_pyarrow_table(data):
# If a table is being passed, we concatenate the columns. This is only valid for
# 'init_score'.
if _is_pyarrow_table(data):
if field_name != "init_score":
raise ValueError(f"pyarrow tables are not supported for field '{field_name}'")
data = pa_chunked_array([
chunk for array in data.columns for chunk in array.chunks # type: ignore
])

c_array = _export_arrow_to_c(data)
_safe_call(_LIB.LGBM_DatasetSetFieldFromArrow(
self._handle,
Expand Down Expand Up @@ -2869,7 +2881,7 @@ def set_init_score(
Parameters
----------
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None
Init score for Booster.
Returns
Expand Down Expand Up @@ -4443,7 +4455,7 @@ def refit(
.. versionadded:: 4.0.0
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
Init score for ``data``.
.. versionadded:: 4.0.0
Expand Down
2 changes: 2 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(self, *args, **kwargs):
from pyarrow import Array as pa_Array
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table
from pyarrow import chunked_array as pa_chunked_array
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_floating as arrow_is_floating
from pyarrow.types import is_integer as arrow_is_integer
Expand Down Expand Up @@ -243,6 +244,7 @@ class pa_compute: # type: ignore
all = None
equal = None

pa_chunked_array = None
arrow_is_integer = None
arrow_is_floating = None

Expand Down
2 changes: 2 additions & 0 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray
metadata_.SetLabel(ca);
} else if (name == std::string("weight") || name == std::string("weights")) {
metadata_.SetWeights(ca);
} else if (name == std::string("init_score")) {
metadata_.SetInitScore(ca);
} else if (name == std::string("query") || name == std::string("group")) {
metadata_.SetQuery(ca);
} else {
Expand Down
28 changes: 20 additions & 8 deletions src/io/metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,32 +355,44 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
}
}

void Metadata::SetInitScore(const double* init_score, data_size_t len) {
template <typename It>
void Metadata::SetInitScoresFromIterator(It first, It last) {
std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr
if (init_score == nullptr || len == 0) {
// Clear init scores on empty input
if (last - first == 0) {
init_score_.clear();
num_init_score_ = 0;
return;
}
if ((len % num_data_) != 0) {
if (((last - first) % num_data_) != 0) {
Log::Fatal("Initial score size doesn't match data size");
}
if (init_score_.empty()) { init_score_.resize(len); }
num_init_score_ = len;
if (init_score_.empty()) {
init_score_.resize(last - first);
}
num_init_score_ = last - first;

#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_init_score_ >= 1024)
for (int64_t i = 0; i < num_init_score_; ++i) {
init_score_[i] = Common::AvoidInf(init_score[i]);
init_score_[i] = Common::AvoidInf(first[i]);
}
init_score_load_from_file_ = false;

#ifdef USE_CUDA
if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetInitScore(init_score_.data(), len);
cuda_metadata_->SetInitScore(init_score_.data(), init_score_.size());
}
#endif // USE_CUDA
}

void Metadata::SetInitScore(const double* init_score, data_size_t len) {
SetInitScoresFromIterator(init_score, init_score + len);
}

void Metadata::SetInitScore(const ArrowChunkedArray& array) {
SetInitScoresFromIterator(array.begin<double>(), array.end<double>());
}

void Metadata::InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size) {
if (num_init_score_ <= 0) {
Log::Fatal("Inserting initial score data into dataset with no initial scores");
Expand Down
45 changes: 44 additions & 1 deletion tests/python_package_test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_dataset_construct_weights_none():
["array_type", "weight_data"],
[(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
)
@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
@pytest.mark.parametrize("arrow_type", _FLOAT_TYPES)
def test_dataset_construct_weights(array_type, weight_data, arrow_type):
data = generate_dummy_arrow_table()
weights = array_type(weight_data, type=arrow_type)
Expand Down Expand Up @@ -210,3 +210,46 @@ def test_dataset_construct_groups(array_type, group_data, arrow_type):

expected = np.array([0, 2, 5], dtype=np.int32)
np_assert_array_equal(expected, dataset.get_field("group"), strict=True)


# ----------------------------------------- INIT SCORES ----------------------------------------- #


@pytest.mark.parametrize(
["array_type", "init_score_data"],
[
(pa.array, [0, 1, 2, 3, 3]),
(pa.chunked_array, [[0, 1, 2], [3, 3]]),
(pa.chunked_array, [[], [0, 1, 2], [3, 3]]),
(pa.chunked_array, [[0, 1], [], [], [2], [3, 3], []]),
],
)
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES)
def test_dataset_construct_init_scores_array(
array_type: Any, init_score_data: Any, arrow_type: Any
):
data = generate_dummy_arrow_table()
init_scores = array_type(init_score_data, type=arrow_type)
dataset = lgb.Dataset(data, init_score=init_scores, params=dummy_dataset_params())
dataset.construct()

expected = np.array([0, 1, 2, 3, 3], dtype=np.float64)
np_assert_array_equal(expected, dataset.get_init_score(), strict=True)


def test_dataset_construct_init_scores_table():
data = generate_dummy_arrow_table()
init_scores = pa.Table.from_arrays(
[
generate_random_arrow_array(5, seed=1),
generate_random_arrow_array(5, seed=2),
generate_random_arrow_array(5, seed=3),
],
names=["a", "b", "c"],
)
dataset = lgb.Dataset(data, init_score=init_scores, params=dummy_dataset_params())
dataset.construct()

actual = dataset.get_init_score()
expected = init_scores.to_pandas().to_numpy().astype(np.float64)
np_assert_array_equal(expected, actual, strict=True)

0 comments on commit f5b6bd6

Please sign in to comment.