[python-package] Fix inconsistency in predict() output shape for 1-tree models #6753

Merged
merged 16 commits into from
Dec 22, 2024
2 changes: 1 addition & 1 deletion python-package/lightgbm/basic.py
@@ -1248,7 +1248,7 @@ def predict(
         if pred_leaf:
             preds = preds.astype(np.int32)
         is_sparse = isinstance(preds, (list, scipy.sparse.spmatrix))
-        if not is_sparse and preds.size != nrow:
+        if not is_sparse and preds.size != nrow or (pred_leaf or pred_contrib):
             if preds.size % nrow == 0:
                 preds = preds.reshape(nrow, -1)
             else:
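
A note on the hunk above: a minimal sketch of why the old condition returned a 1-D array for a 1-tree model, using NumPy only. The helper name `reshape_dense_preds` is hypothetical, it covers dense arrays only (the `is_sparse` check is left out), and the error raised in the `else` branch is an assumption because the hunk is truncated there. Since `and` binds tighter than `or` in Python, the changed line as shown groups as `(not is_sparse and preds.size != nrow) or (pred_leaf or pred_contrib)`.

```python
import numpy as np


def reshape_dense_preds(preds, nrow, pred_leaf=False, pred_contrib=False):
    """Hypothetical helper mirroring the dense path of the changed condition."""
    # Old condition: reshape only when the flat size differs from nrow.
    # With one tree, pred_leaf yields exactly nrow values, so no reshape happened.
    # New condition: also reshape whenever pred_leaf or pred_contrib is set.
    if preds.size != nrow or pred_leaf or pred_contrib:
        if preds.size % nrow == 0:
            preds = preds.reshape(nrow, -1)
        else:
            # Assumed error handling; the diff does not show this branch.
            raise ValueError(
                f"Length of predict result ({preds.size}) is not a multiple of nrow ({nrow})"
            )
    return preds


nrow = 4
flat_one_tree = np.zeros(nrow, dtype=np.int32)       # 1 tree: one leaf index per row
flat_two_trees = np.zeros(nrow * 2, dtype=np.int32)  # 2 trees: two leaf indices per row

shape_plain = reshape_dense_preds(flat_one_tree, nrow).shape
shape_leaf_1 = reshape_dense_preds(flat_one_tree, nrow, pred_leaf=True).shape
shape_leaf_2 = reshape_dense_preds(flat_two_trees, nrow, pred_leaf=True).shape
```

With the flag passed, the 1-tree result becomes 2-D with a single column, matching the shape of the 2-tree case; without it, the flat array is returned unchanged.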
26 changes: 25 additions & 1 deletion tests/python_package_test/test_engine.py
@@ -15,7 +15,7 @@
 import psutil
 import pytest
 from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr
-from sklearn.datasets import load_svmlight_file, make_blobs, make_multilabel_classification
+from sklearn.datasets import load_svmlight_file, make_blobs, make_multilabel_classification, make_regression
from sklearn.metrics import average_precision_score, log_loss, mean_absolute_error, mean_squared_error, roc_auc_score
from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split

@@ -2307,6 +2307,30 @@ def test_refit():
     assert err_pred > new_err_pred


+def test_refit_with_one_tree():
+    X, y = load_breast_cancer(return_X_y=True)
+    lgb_train = lgb.Dataset(X, label=y)
+    params = {"objective": "binary", "verbosity": -1}
+    model = lgb.train(params, lgb_train, num_boost_round=1)
+    model_refit = model.refit(X, y)
+    assert isinstance(model_refit, lgb.Booster)
+
+    X, y = make_regression(n_samples=10_000, n_features=10)
+    lgb_train = lgb.Dataset(X, label=y)
+    params = {"objective": "regression", "verbosity": -1}
+    model = lgb.train(params, lgb_train, num_boost_round=1)
+    model_refit = model.refit(X, y)
+    assert isinstance(model_refit, lgb.Booster)
Collaborator

If this test for regression is going to repeat all of the same code (totally fine, repetition in test code can be helpful!), then let's please just make it a separate test case.

def test_refit_with_one_tree_regression():
   ...

def test_refit_with_one_tree_binary_classification():
   ...

def test_refit_with_one_tree_multiclass_classification():
   ...

That way, the test could be targeted individually like

pytest './tests/python_package_test/test_engine.py::test_refit_with_one_tree_binary_classification'

Contributor Author

Thanks, I added a multiclass example and split the test into separate test cases in 03e41c6.



+def test_pred_leaf_output_shape():
+    X, y = make_regression(n_samples=10_000, n_features=10)
+    dtrain = lgb.Dataset(X, label=y)
+    params = {"objective": "regression", "verbosity": -1}
+    assert lgb.train(params, dtrain, num_boost_round=1).predict(X, pred_leaf=True).shape == (10_000, 1)
+    assert lgb.train(params, dtrain, num_boost_round=2).predict(X, pred_leaf=True).shape == (10_000, 2)
Collaborator

Thanks!

Please put this test down here by other .predict() tests:

def test_predict_stump(rng, use_init_score):

And let's please:

  • test all of .predict(X, pred_leaf=True), .predict(X, pred_contrib=True), and .predict(X)
  • change the test names to e.g. test_predict_regression_output_shape
  • add similar tests for binary classification and 3-class multi-class classification
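
For reference, the shapes such tests would check can be sketched as a small lookup. This is a hedged illustration, not code from the PR: `expected_predict_shapes` is a hypothetical helper. The `pred_leaf` shapes for regression match the asserts in the test above; the `pred_contrib` and multiclass shapes are assumptions based on LightGBM's usual dense output layout (one extra expected-value column per class, one tree per class per boosting round). Binary classification is treated as a single-output model (`n_classes=1`).

```python
def expected_predict_shapes(n_rows, n_features, n_rounds, n_classes=1):
    """Hypothetical helper: assumed dense output shapes of Booster.predict()."""
    # Assumption: a multiclass model trains one tree per class in each round.
    leaf_cols = n_rounds * n_classes
    # Assumption: pred_contrib returns one column per feature plus one
    # expected-value column, repeated for every class.
    contrib_cols = (n_features + 1) * n_classes
    # Single-output models return a flat vector; multiclass returns one column per class.
    plain = (n_rows,) if n_classes == 1 else (n_rows, n_classes)
    return {
        "predict": plain,
        "pred_leaf": (n_rows, leaf_cols),
        "pred_contrib": (n_rows, contrib_cols),
    }


# 1-tree regression model on 10 features (also the assumed binary case).
regression = expected_predict_shapes(10_000, 10, n_rounds=1)
# 3-class model trained for one boosting round.
multiclass = expected_predict_shapes(10_000, 10, n_rounds=1, n_classes=3)
```

A test could compare each entry against the `.shape` of the matching `predict()` call, for example `booster.predict(X, pred_leaf=True).shape == regression["pred_leaf"]`.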

Contributor Author

Thanks! I added shape-check tests for predict(), including pred_leaf and pred_contrib, for the binary and multiclass cases as you suggested, in d73f189.



 def test_refit_dataset_params(rng):
     # check refit accepts dataset_params
     X, y = load_breast_cancer(return_X_y=True)