From f0aa6c2e0c854becafbfc8268f724754273e128a Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 2 Dec 2024 20:48:07 -0600 Subject: [PATCH 1/3] [python-package] simplify scikit-learn 1.6+ tags support --- python-package/lightgbm/compat.py | 8 -------- python-package/lightgbm/sklearn.py | 15 +++++---------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 0b9444b0ecbf..2abab44a902a 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -14,14 +14,6 @@ from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.validation import assert_all_finite, check_array, check_X_y - # sklearn.utils Tags types can be imported unconditionally once - # lightgbm's minimum scikit-learn version is 1.6 or higher - try: - from sklearn.utils import ClassifierTags as _sklearn_ClassifierTags - from sklearn.utils import RegressorTags as _sklearn_RegressorTags - except ImportError: - _sklearn_ClassifierTags = None - _sklearn_RegressorTags = None try: from sklearn.exceptions import NotFittedError from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index d730b66c3556..108ef1e14498 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -40,8 +40,6 @@ _LGBMModelBase, _LGBMRegressorBase, _LGBMValidateData, - _sklearn_ClassifierTags, - _sklearn_RegressorTags, _sklearn_version, dt_DataTable, pd_DataFrame, @@ -726,7 +724,7 @@ def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]: # take whatever tags are provided by BaseEstimator, then modify # them with LightGBM-specific values return self._update_sklearn_tags_from_dict( - tags=_LGBMModelBase.__sklearn_tags__(self), + tags=super().__sklearn_tags__(), tags_dict=self._more_tags(), ) @@ -1298,10 +1296,7 @@ def _more_tags(self) -> Dict[str, Any]: return tags def __sklearn_tags__(self) -> "_sklearn_Tags": - tags = LGBMModel.__sklearn_tags__(self) - tags.estimator_type = "regressor" - tags.regressor_tags = _sklearn_RegressorTags(multi_label=False) - return tags + return super().__sklearn_tags__() def fit( # type: ignore[override] self, @@ -1360,9 +1355,9 @@ def _more_tags(self) -> Dict[str, Any]: return tags def __sklearn_tags__(self) -> "_sklearn_Tags": - tags = LGBMModel.__sklearn_tags__(self) - tags.estimator_type = "classifier" - tags.classifier_tags = _sklearn_ClassifierTags(multi_class=True, multi_label=False) + tags = super().__sklearn_tags__() + tags.classifier_tags.multi_class = True + tags.classifier_tags.multi_label = False return tags def fit( # type: ignore[override] From 16e1976ce3613e5d41cdabb76d86b0637c5f07f9 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 2 Dec 2024 21:02:22 -0600 Subject: [PATCH 2/3] add more tests on __sklearn_tags__ --- tests/python_package_test/test_sklearn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index d187e9df5a9f..1cdd047f1857 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1488,6 +1488,12 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato assert sklearn_tags.input_tags.allow_nan is True assert sklearn_tags.input_tags.sparse is True assert sklearn_tags.target_tags.one_d_labels is True + if estimator_class is lgb.LGBMClassifier: + assert sklearn_tags.estimator_type == "classifier" + assert sklearn_tags.classifier_tags.multi_class is True + assert sklearn_tags.classifier_tags.multi_label is False + elif estimator_class is lgb.LGBMRegressor: + assert sklearn_tags.estimator_type == "regressor" @pytest.mark.parametrize("task", all_tasks) From 66033970719da915db655fb8c7dcc1bb2f7e636f Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 3 Dec 2024 11:44:20 -0600 Subject: [PATCH 3/3] remove fallback classes --- python-package/lightgbm/compat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 2abab44a902a..96dee6522572 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -140,8 +140,6 @@ class _LGBMRegressorBase: # type: ignore _LGBMCheckClassificationTargets = None _LGBMComputeSampleWeight = None _LGBMValidateData = None - _sklearn_ClassifierTags = None - _sklearn_RegressorTags = None _sklearn_version = None # additional scikit-learn imports only for type hints