[python-package] Fix inconsistency in `predict()` output shape for 1-tree models (#6753)
```diff
@@ -2307,6 +2307,15 @@ def test_refit():
     assert err_pred > new_err_pred


+def test_refit_with_one_tree():
+    X, y = load_breast_cancer(return_X_y=True)
+    lgb_train = lgb.Dataset(X, label=y)
+    params = {"objective": "binary", "num_trees": 1, "verbosity": -1}
+    model = lgb.train(params, lgb_train, num_boost_round=1)
+    model_refit = model.refit(X, y)
+    assert isinstance(model_refit, lgb.Booster)


 def test_refit_dataset_params(rng):
     # check refit accepts dataset_params
     X, y = load_breast_cancer(return_X_y=True)
```

**Reviewer** (suggested change on the `lgb.train(...)` line): Passing both … Can you also please add tests for regression and multi-class classification? Note that #6737 was specifically about regression... we should make sure this fixes that issue for that specific case.

**Reviewer:** That's fine, but please add a multi-class classification test anyway. That'll help prevent us from introducing this bug for multi-class classification in the future.
**Reviewer:** Should we add … or `pred_contrib` here as well?

**Author:** I think it's unnecessary, since there's no place where the shape is used after `pred_contrib=True`. `pred_leaf` is necessary because of `python-package/lightgbm/basic.py`, lines 4882 to 4888 (at b33a12e).
**Reviewer:** @RektPunk `Booster.predict()` is an important part of LightGBM's public API. If it's producing incorrect results, that needs to be fixed. We've set the expectation that `.predict(X, pred_leaf=True)`, for example, returns a matrix with shape `(X.shape[0], num_trees)` for regression, and right now that is not true for a single-tree model. Even if no other code inside `lightgbm` referenced the `.shape` attribute of that output, this would still be a bug and we'd still want to fix it... because any other code using `lightgbm` should be able to form a consistent expectation of the return shape, and not need a special case for 1-tree models.

I think we should add it for completeness, but the existing condition is technically already enough to get the correct output shape: `preds.size != nrow` will always be true for `pred_contrib`, because the output contains one value per feature AND the Shapley base value. In other words, it will always have `nrow * (num_features + 1)` elements, so the reshaping will always be done, and there's no bug here for `pred_contrib` like there is for `pred_leaf`.

So please, let's do the following:

- change the condition to `if not is_sparse and preds.size != nrow or (pred_leaf or pred_contrib)`
- add tests covering `pred_leaf` and `pred_contrib` with 1-iteration and 2-iteration models (using the snippet I provided above as a reference)

@RektPunk if you'd prefer that I add those additional tests, let me know and I'll push to your branch.
**Author:** Thanks for the quick review. As you said, the condition is changed in 088a2b9, and an explicit test is added in ba39a6f. I'm not sure about the exact location of the newly added test, though.
**Author:** I changed the condition to `if not is_sparse and (preds.size != nrow or pred_leaf or pred_contrib)`, since the previously suggested one made a test fail: link.

**Reviewer:** Ah great, thank you! Sorry for suggesting the wrong form... I never remember exactly how a condition like `if A and B or C or D` is evaluated in Python 😅 The parentheses make it easier to understand.
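For reference, the precedence point: in Python, `and` binds more tightly than `or`, so the originally suggested `A and B or C or D` groups as `(A and B) or C or D`, not as `A and (B or C or D)`. A quick demonstration:

```python
# `and` binds tighter than `or`: A and B or C or D == (A and B) or C or D.
A, B, C, D = False, True, True, False

assert (A and B or C or D) == ((A and B) or C or D)

# With explicit parentheses around the `or` chain, the result differs —
# exactly the difference between the two conditions discussed above.
assert (A and (B or C or D)) != (A and B or C or D)
```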
**Author:** No worries at all, thank you so much for your kind suggestion :)