Skip to content

Commit

Permalink
[python-package] do not copy column-major numpy arrays when predicting
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Dec 15, 2024
1 parent b33a12e commit 1090a93
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
7 changes: 2 additions & 5 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,10 +1291,7 @@ def __inner_predict_np2d(
predict_type: int,
preds: Optional[np.ndarray],
) -> Tuple[np.ndarray, int]:
if mat.dtype == np.float32 or mat.dtype == np.float64:
data = np.asarray(mat.reshape(mat.size), dtype=mat.dtype)
else: # change non-float data to float data, need to copy
data = np.array(mat.reshape(mat.size), dtype=np.float32)
data, layout = _np2d_to_np1d(mat)
ptr_data, type_ptr_data, _ = _c_float_array(data)
n_preds = self.__get_num_preds(
start_iteration=start_iteration,
Expand All @@ -1314,7 +1311,7 @@ def __inner_predict_np2d(
ctypes.c_int(type_ptr_data),
ctypes.c_int32(mat.shape[0]),
ctypes.c_int32(mat.shape[1]),
ctypes.c_int(_C_API_IS_ROW_MAJOR),
ctypes.c_int(layout),
ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
Expand Down
15 changes: 15 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4611,3 +4611,18 @@ def test_bagging_by_query_in_lambdarank():
ndcg_score_no_bagging_by_query = gbm_no_bagging_by_query.best_score["valid_0"]["ndcg@5"]
assert ndcg_score_bagging_by_query >= ndcg_score - 0.1
assert ndcg_score_no_bagging_by_query >= ndcg_score - 0.1


def test_equal_predict_from_row_major_and_col_major_data():
    """Check that predictions do not depend on the input's memory layout.

    A booster is trained on the row-major matrix, then asked to predict on
    that same matrix and on a column-major (Fortran-ordered) version of it
    produced by ``np.asfortranarray``. Both prediction arrays must match
    within ``np.testing.assert_allclose`` default tolerances.
    """
    X_row, y = make_synthetic_regression()
    # Guard the precondition: the fixture must be strictly C-ordered, otherwise
    # the comparison below would not exercise two distinct layouts.
    row_flags = X_row.flags
    assert row_flags["C_CONTIGUOUS"] and not row_flags["F_CONTIGUOUS"]

    train_set = lgb.Dataset(X_row, y)
    booster = lgb.train(
        {"num_leaves": 8, "verbose": -1},
        train_set,
        num_boost_round=5,
    )
    preds_row = booster.predict(X_row)

    # Same values, column-major storage.
    X_col = np.asfortranarray(X_row)
    col_flags = X_col.flags
    assert col_flags["F_CONTIGUOUS"] and not col_flags["C_CONTIGUOUS"]
    preds_col = booster.predict(X_col)

    np.testing.assert_allclose(preds_row, preds_col)

0 comments on commit 1090a93

Please sign in to comment.