[python-package] Fix mypy errors for fit() incompatible signature #5679
Conversation
Thanks for helping with this!

But I'm sorry, I disagree with most of the proposed changes. The fit() methods of the estimators in lightgbm.sklearn should not be identical to each other... they expect different arguments for the different machine learning tasks (e.g., group is only relevant to ranking).

That same principle applies to the fit() methods on classes in lightgbm.dask, which are intended to mimic them. In addition, the lightgbm.dask estimators cannot simply have the same types for their arguments as the underlying lightgbm.sklearn estimators that they inherit from... special consideration has to be given there to how distributed training works and how to orchestrate moving data around multiple processes / machines.

If you haven't used Dask or LightGBM before, I recommend not closing this PR, but not attempting to fix the mypy warnings related to these fit() methods at this time.
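For background, the mypy errors in question come from its override-compatibility check: a subclass method whose parameters are narrower than the parent's is flagged. A minimal sketch with made-up classes (not the actual LightGBM hierarchy) that reproduces the same class of error:

```python
from typing import List, Sequence


class Model:
    def fit(self, X: Sequence[float], y: Sequence[float]) -> "Model":
        return self


class DaskModel(Model):
    # X is narrower here than in the base class, so mypy reports something like:
    #   error: Argument 1 of "fit" is incompatible with supertype "Model"  [override]
    def fit(self, X: List[float], y: Sequence[float]) -> "DaskModel":
        return self
```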
python-package/lightgbm/dask.py
Outdated
@@ -1166,14 +1167,20 @@ def fit(
     X: _DaskMatrixLike,
     y: _DaskCollection,
     sample_weight: Optional[_DaskVectorLike] = None,
-    init_score: Optional[_DaskCollection] = None,
+    init_score: Optional[_DaskVectorLike] = None,
+    group: Optional[_DaskVectorLike] = None,
group is not supported by DaskLGBMClassifier. That argument is only relevant for ranking problems, not for classification or regression. Please remove this.
python-package/lightgbm/dask.py
Outdated
@@ -1166,14 +1167,20 @@ def fit(
     X: _DaskMatrixLike,
     y: _DaskCollection,
     sample_weight: Optional[_DaskVectorLike] = None,
-    init_score: Optional[_DaskCollection] = None,
+    init_score: Optional[_DaskVectorLike] = None,
This change is not correct. init_score for multi-class classification can be multidimensional (e.g. 1 row per observation, 1 column per class), so _DaskVectorLike (which only has 1-dimensional types) is not appropriate.
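To illustrate the dimensionality point with plain NumPy (a sketch; the Dask estimators use the corresponding distributed collections):

```python
import numpy as np

n_samples, n_classes = 100, 3

# Binary classification / regression: one initial score per row, so 1-D is enough.
init_score_binary = np.zeros(n_samples)                    # shape (100,)

# Multi-class classification: one column per class, so the value is 2-D
# and a strictly 1-D "vector-like" annotation would reject it.
init_score_multiclass = np.zeros((n_samples, n_classes))   # shape (100, 3)
```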
python-package/lightgbm/dask.py
Outdated
     eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
     **kwargs: Any
     eval_at: Union[List[int], Tuple[int, ...]] = (1, 2, 3, 4, 5),
eval_at is only relevant for ranking applications, not classification or regression. This doesn't belong on DaskLGBMClassifier.
python-package/lightgbm/dask.py
Outdated
     feature_name: str = 'auto',
     categorical_feature: str = 'auto',
     callbacks: Optional[List[Callable]] = None,
     init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None
Adding support for training continuation in the Dask interface isn't as simple as just adding this argument with the same types as the underlying scikit-learn estimators.

For example, what is the user experience if this is a str or Path (implying a pretrained model in a local file)? Should the client code expect to find such a file on its local storage and broadcast it to all workers? Should every worker be expected to already have an identical copy of that file at an identical filepath on its local storage?

That is its own feature work that should be done (with new docs and tests) in #4063.
python-package/lightgbm/dask.py
Outdated
     **kwargs: Any
     feature_name: str = 'auto',
     categorical_feature: str = 'auto',
     callbacks: Optional[List[Callable]] = None,
Supporting callbacks in the Dask interface isn't as simple as just adding this argument to the signatures of these methods.
For example, unlike in the sklearn estimators these inherit from, in Dask all callbacks must be pickleable so they can be sent from the client to workers. (See the discussion in #5080 for some background.)
I don't support adding this argument to the signature like this.
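To make that constraint concrete, here is a hypothetical callback (not part of LightGBM's API) that works fine when everything runs in one process but cannot be serialized for shipping to Dask workers, because it holds an open file handle:

```python
import pickle


class RecordIterationsToFile:
    """Hypothetical LightGBM-style callback that logs each boosting iteration to a file."""

    def __init__(self, path: str) -> None:
        self._fh = open(path, "w")  # an open file handle is not picklable

    def __call__(self, env) -> None:
        self._fh.write(f"finished iteration {env.iteration}\n")


cb = RecordIterationsToFile("training_log.txt")
try:
    pickle.dumps(cb)
except TypeError as exc:
    # e.g. "cannot pickle '_io.TextIOWrapper' object"
    print(f"this callback could not be sent to workers: {exc}")
```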
python-package/lightgbm/sklearn.py
Outdated
@@ -1201,6 +1212,7 @@ def fit(
     eval_set=None,
     eval_names: Optional[List[str]] = None,
     eval_sample_weight=None,
+    eval_class_weight=None,
eval_class_weight is not relevant for the ranking task... it only matters for classification. Please remove this.
python-package/lightgbm/sklearn.py
Outdated
     eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
     eval_at: Union[List[int], Tuple[int, ...]] = (1, 2, 3, 4, 5),
group, eval_at, and eval_group are not relevant for the classification task. They're specific to ranking. Please revert these changes adding them to LGBMClassifier.fit().
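For anyone following along who hasn't used the ranking estimator: group (and its eval_* counterparts) describes how many consecutive rows of the data belong to each query, which has no meaning for classification. A rough sketch with synthetic data, assuming lightgbm is installed:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(42)
X = rng.normal(size=(12, 4))
y = rng.integers(0, 4, size=12)   # graded relevance labels
group = [5, 4, 3]                 # three queries with 5, 4, and 3 documents (sums to 12 rows)

ranker = lgb.LGBMRanker(n_estimators=5, min_child_samples=1)
ranker.fit(
    X,
    y,
    group=group,
    eval_set=[(X, y)],
    eval_group=[group],
    eval_at=[1, 3],               # report NDCG@1 and NDCG@3
)
```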
Since mypy does not accept different signatures for child classes, these signature-difference errors can't really be solved by changing the signatures. Would you rather just silence them with # type: ignore?
I'm willing to look at a proposal for that! If it's possible to specifically opt out of these signature-mismatch warnings, that would be even better.
It is possible! I've reverted the previous changes and added the override ignore tag to the fit method to silence these errors specifically.
Very nice, thank you so much! I didn't know about # type: ignore[override] before. That's exactly what I was hoping for.

The failing CI jobs aren't related to your change. Once we fix them (in #5689), I'll re-run them for this PR and merge it.

Thanks again for the help @IdoKendo !
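For reference, the directive is a per-definition suppression: the error code in brackets limits it to mypy's [override] check, so other checks still apply. A minimal sketch with hypothetical classes (not the LightGBM ones):

```python
from typing import List, Optional, Sequence


class Model:
    def fit(self, X: Sequence[float], y: Sequence[float]) -> "Model":
        return self


class SpecializedModel(Model):
    # Deliberately incompatible with the base signature (X is narrower, extra argument);
    # the comment silences only the [override] error reported on this "def" line.
    def fit(  # type: ignore[override]
        self,
        X: List[float],
        y: Sequence[float],
        group: Optional[List[int]] = None,
    ) -> "SpecializedModel":
        return self
```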
This pull request has been automatically locked since there has not been any recent activity after it was closed.
Contributes to #3867.
mypy currently raises the following errors regarding the signature of the fit() method:

Notes for Reviewers
This was tested by running mypy as documented in #3867.