From 16e1976ce3613e5d41cdabb76d86b0637c5f07f9 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 2 Dec 2024 21:02:22 -0600 Subject: [PATCH] add more tests on __sklearn_tags__ --- tests/python_package_test/test_sklearn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index d187e9df5a9f..1cdd047f1857 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1488,6 +1488,12 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato assert sklearn_tags.input_tags.allow_nan is True assert sklearn_tags.input_tags.sparse is True assert sklearn_tags.target_tags.one_d_labels is True + if estimator_class is lgb.LGBMClassifier: + assert sklearn_tags.estimator_type == "classifier" + assert sklearn_tags.classifier_tags.multi_class is True + assert sklearn_tags.classifier_tags.multi_label is False + elif estimator_class is lgb.LGBMRegressor: + assert sklearn_tags.estimator_type == "regressor" @pytest.mark.parametrize("task", all_tasks)