diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index d187e9df5a9f..1cdd047f1857 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1488,6 +1488,12 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato assert sklearn_tags.input_tags.allow_nan is True assert sklearn_tags.input_tags.sparse is True assert sklearn_tags.target_tags.one_d_labels is True + if estimator_class is lgb.LGBMClassifier: + assert sklearn_tags.estimator_type == "classifier" + assert sklearn_tags.classifier_tags.multi_class is True + assert sklearn_tags.classifier_tags.multi_label is False + elif estimator_class is lgb.LGBMRegressor: + assert sklearn_tags.estimator_type == "regressor" @pytest.mark.parametrize("task", all_tasks)