[python-package] do not copy column-major numpy arrays when creating Dataset (#6721)

* do not copy column-major numpy arrays when creating Dataset

* fix logic

* lint

* code review

* update test

* move dataset test to basic

* increase features

* assert single layout

---------

Co-authored-by: Nikita Titov <[email protected]>
jmoralez and StrikerRUS authored Dec 10, 2024
1 parent 33764e1 commit ae76aad
Showing 2 changed files with 57 additions and 7 deletions.
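
At its core the change relies on layout-aware flattening: when a 2-D float array already has the requested dtype and memory order, numpy's asarray and ravel return views rather than copies, so column-major (Fortran-ordered) input can be handed to the C API unchanged together with the matching layout flag. A minimal numpy-only sketch of that behavior (the names below are illustrative, not part of the patch):

import numpy as np

mat = np.asfortranarray(np.ones((4, 3), dtype=np.float64))  # column-major float64

# dtype and order already match, so asarray is a no-copy pass-through
data = np.asarray(mat, dtype=mat.dtype, order="F")
# ravelling in the array's own order yields a view, not a copy
flat = data.ravel(order="F")

assert np.shares_memory(mat, flat)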
28 changes: 21 additions & 7 deletions python-package/lightgbm/basic.py
@@ -188,6 +188,23 @@ def _get_sample_count(total_nrow: int, params: str) -> int:
     return sample_cnt.value


+def _np2d_to_np1d(mat: np.ndarray) -> Tuple[np.ndarray, int]:
+    if mat.dtype in (np.float32, np.float64):
+        dtype = mat.dtype
+    else:
+        dtype = np.float32
+    if mat.flags["F_CONTIGUOUS"]:
+        order = "F"
+        layout = _C_API_IS_COL_MAJOR
+    else:
+        order = "C"
+        layout = _C_API_IS_ROW_MAJOR
+    # ensure dtype and order, copies if either do not match
+    data = np.asarray(mat, dtype=dtype, order=order)
+    # flatten array without copying
+    return data.ravel(order=order), layout
+
+
 class _MissingType(Enum):
     NONE = "None"
     NAN = "NaN"
@@ -684,7 +701,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
 _C_API_DTYPE_INT32 = 2
 _C_API_DTYPE_INT64 = 3

-"""Matrix is row major in Python"""
+"""Macro definition of data order in matrix"""
+_C_API_IS_COL_MAJOR = 0
 _C_API_IS_ROW_MAJOR = 1

 """Macro definition of prediction type in C API of LightGBM"""
@@ -2297,19 +2315,15 @@ def __init_from_np2d(
             raise ValueError("Input numpy.ndarray must be 2 dimensional")

         self._handle = ctypes.c_void_p()
-        if mat.dtype == np.float32 or mat.dtype == np.float64:
-            data = np.asarray(mat.reshape(mat.size), dtype=mat.dtype)
-        else:  # change non-float data to float data, need to copy
-            data = np.asarray(mat.reshape(mat.size), dtype=np.float32)
-
+        data, layout = _np2d_to_np1d(mat)
         ptr_data, type_ptr_data, _ = _c_float_array(data)
         _safe_call(
             _LIB.LGBM_DatasetCreateFromMat(
                 ptr_data,
                 ctypes.c_int(type_ptr_data),
                 ctypes.c_int32(mat.shape[0]),
                 ctypes.c_int32(mat.shape[1]),
-                ctypes.c_int(_C_API_IS_ROW_MAJOR),
+                ctypes.c_int(layout),
                 _c_str(params_str),
                 ref_dataset,
                 ctypes.byref(self._handle),
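
For context, a user-level sketch of the path this touches (assumes lightgbm and numpy are installed; the data below is made up): constructing a Dataset from a Fortran-ordered float array now goes through the column-major branch instead of first being copied into row-major order.

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X = np.asfortranarray(rng.random((1_000, 3)))  # F-contiguous float64
y = rng.integers(0, 2, size=1_000)

ds = lgb.Dataset(X, label=y)
ds.construct()  # X is passed to LGBM_DatasetCreateFromMat in column-major layout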
36 changes: 36 additions & 0 deletions tests/python_package_test/test_basic.py
@@ -947,3 +947,39 @@ def test_max_depth_warning_is_raised_if_max_depth_gte_5_and_num_leaves_omitted(c
         "in params. Alternatively, pass (max_depth=-1) and just use 'num_leaves' to constrain model complexity."
     )
     assert expected_warning in capsys.readouterr().out
+
+
+@pytest.mark.parametrize("order", ["C", "F"])
+@pytest.mark.parametrize("dtype", ["float32", "int64"])
+def test_no_copy_in_dataset_from_numpy_2d(rng, order, dtype):
+    X = rng.random(size=(100, 3))
+    X = np.require(X, dtype=dtype, requirements=order)
+    X1d, layout = lgb.basic._np2d_to_np1d(X)
+    if order == "F":
+        assert layout == lgb.basic._C_API_IS_COL_MAJOR
+    else:
+        assert layout == lgb.basic._C_API_IS_ROW_MAJOR
+    if dtype == "float32":
+        assert np.shares_memory(X, X1d)
+    else:
+        # makes a copy
+        assert not np.shares_memory(X, X1d)
+
+
+def test_equal_datasets_from_row_major_and_col_major_data(tmp_path):
+    # row-major dataset
+    X_row, y = make_blobs(n_samples=1_000, n_features=3, centers=2)
+    assert X_row.flags["C_CONTIGUOUS"] and not X_row.flags["F_CONTIGUOUS"]
+    ds_row = lgb.Dataset(X_row, y)
+    ds_row_path = tmp_path / "ds_row.txt"
+    ds_row._dump_text(ds_row_path)
+
+    # col-major dataset
+    X_col = np.asfortranarray(X_row)
+    assert X_col.flags["F_CONTIGUOUS"] and not X_col.flags["C_CONTIGUOUS"]
+    ds_col = lgb.Dataset(X_col, y)
+    ds_col_path = tmp_path / "ds_col.txt"
+    ds_col._dump_text(ds_col_path)
+
+    # check datasets are equal
+    assert filecmp.cmp(ds_row_path, ds_col_path)
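
A side note on np.require as used in the first test above (numpy only; the array here is illustrative): it returns an array satisfying the requested dtype and memory-layout requirements, copying only when necessary, which is what lets the test produce each of the four dtype/order combinations.

import numpy as np

a = np.arange(6, dtype=np.float64).reshape(2, 3)       # C-contiguous float64
b = np.require(a, dtype="float32", requirements="F")   # copies: dtype and order change

assert b.dtype == np.float32 and b.flags["F_CONTIGUOUS"]
assert not np.shares_memory(a, b)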
