[python-package] require scikit-learn>=0.24.2, make scikit-learn estimators compatible with `scikit-learn>=1.6.0dev` (#6651)

Co-authored-by: James Lamb <[email protected]>
Co-authored-by: Nikita Titov <[email protected]>
3 people authored Oct 9, 2024
1 parent 0643230 commit 7eae66a
Showing 7 changed files with 309 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .ci/test-python-latest.sh
@@ -22,7 +22,7 @@ python -m pip install \
    'numpy>=2.0.0.dev0' \
    'matplotlib>=3.10.0.dev0' \
    'pandas>=3.0.0.dev0' \
    'scikit-learn==1.5.*' \
    'scikit-learn>=1.6.dev0' \
    'scipy>=1.15.0.dev0'

python -m pip install \
2 changes: 1 addition & 1 deletion .ci/test-python-oldest.sh
@@ -15,7 +15,7 @@ pip install \
    'numpy==1.19.0' \
    'pandas==1.1.3' \
    'pyarrow==6.0.1' \
    'scikit-learn==0.24.0' \
    'scikit-learn==0.24.2' \
    'scipy==1.6.0' \
    || exit 1
echo "done installing lightgbm's dependencies"
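This job now pins the oldest tested scikit-learn at 0.24.2 rather than 0.24.0, matching the new floor named in the commit title. A minimal sketch (illustrative, not part of this diff) of enforcing that floor at import time, assuming the `packaging` library is available:

```python
from packaging.version import Version
from sklearn import __version__ as sklearn_version

# 0.24.2 is the floor this commit requires; per the compat.py comment below,
# sklearn.utils.validation._num_features first appeared in that release
if Version(sklearn_version) < Version("0.24.2"):
    raise ImportError(f"scikit-learn>=0.24.2 is required, found {sklearn_version}")
```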
1 change: 1 addition & 0 deletions .ci/test.sh
@@ -103,6 +103,7 @@ if [[ $TASK == "lint" ]]; then
        'mypy>=1.11.1' \
        'pre-commit>=3.8.0' \
        'pyarrow-core>=17.0' \
        'scikit-learn>=1.5.2' \
        'r-lintr>=3.1.2'
    source activate $CONDA_ENV
    echo "Linting Python code"
88 changes: 83 additions & 5 deletions python-package/lightgbm/compat.py
@@ -1,12 +1,13 @@
# coding: utf-8
"""Compatibility library."""

from typing import Any, List
from typing import TYPE_CHECKING, Any, List

# scikit-learn is intentionally imported first here,
# see https://github.com/microsoft/LightGBM/issues/6509
"""sklearn"""
try:
    from sklearn import __version__ as _sklearn_version
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
@@ -29,6 +30,74 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
            check_consistent_length(sample_weight, X)
            return sample_weight

    try:
        from sklearn.utils.validation import validate_data
    except ImportError:
        # validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
        # It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
        def validate_data(
            _estimator,
            X,
            y="no_validation",
            accept_sparse: bool = True,
            # 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
            ensure_all_finite: bool = False,
            ensure_min_samples: int = 1,
            # trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
            **ignored_kwargs,
        ):
            # it's safe to import _num_features unconditionally because:
            #
            # * it was first added in scikit-learn 0.24.2
            # * lightgbm cannot be used with scikit-learn versions older than that
            # * this validate_data() re-implementation will not be called in scikit-learn>=1.6
            #
            from sklearn.utils.validation import _num_features

            # _num_features() raises a TypeError on 1-dimensional input. That's a problem
            # because scikit-learn's 'check_fit1d' estimator check sets the expectation that
            # estimators must raise a ValueError when a 1-dimensional input is passed to fit().
            #
            # So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
            if hasattr(X, "shape") and len(X.shape) == 1:
                n_features_in_ = 1
            else:
                n_features_in_ = _num_features(X)

            no_val_y = isinstance(y, str) and y == "no_validation"

            # NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
            if no_val_y:
                X = check_array(
                    X,
                    accept_sparse=accept_sparse,
                    force_all_finite=ensure_all_finite,
                    ensure_min_samples=ensure_min_samples,
                )
            else:
                X, y = check_X_y(
                    X,
                    y,
                    accept_sparse=accept_sparse,
                    force_all_finite=ensure_all_finite,
                    ensure_min_samples=ensure_min_samples,
                )

            # this only needs to be updated at fit() time
            _estimator.n_features_in_ = n_features_in_

            # raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
            if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
                raise ValueError(
                    f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
                    f"is expecting {_estimator._n_features} features as input."
                )

            if no_val_y:
                return X
            else:
                return X, y

    SKLEARN_INSTALLED = True
    _LGBMBaseCrossValidator = BaseCrossValidator
    _LGBMModelBase = BaseEstimator
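The fallback above mirrors scikit-learn 1.6's `validate_data()`: it counts features, runs `check_array()`/`check_X_y()`, records `n_features_in_` on the estimator, and raises the same feature-count mismatch error on later calls. A minimal, hypothetical sketch of how an estimator could call it through the `_LGBMValidateData` alias assigned just below; `_DemoEstimator` is illustrative only and not code from this commit:

```python
from sklearn.base import BaseEstimator

from lightgbm.compat import _LGBMValidateData


class _DemoEstimator(BaseEstimator):
    """Illustrative only; not a LightGBM estimator."""

    def __sklearn_is_fitted__(self) -> bool:
        # the shim consults this before comparing feature counts on repeated fits
        return hasattr(self, "_n_features")

    def fit(self, X, y):
        # validates X and y, and sets self.n_features_in_ as a side effect
        X, y = _LGBMValidateData(self, X, y, ensure_all_finite=False)
        self._n_features = self.n_features_in_
        return self
```

Calling `fit()` a second time with a different number of features would then raise the same ValueError shown in the shim above.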
@@ -38,12 +107,11 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
    _LGBMValidateData = validate_data
except ImportError:
    SKLEARN_INSTALLED = False

@@ -67,12 +135,22 @@ class _LGBMRegressorBase: # type: ignore
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None
    _LGBMValidateData = None
    _sklearn_version = None

# additional scikit-learn imports only for type hints
if TYPE_CHECKING:
    # sklearn.utils.Tags can be imported unconditionally once
    # lightgbm's minimum scikit-learn version is 1.6 or higher
    try:
        from sklearn.utils import Tags as _sklearn_Tags
    except ImportError:
        _sklearn_Tags = None


"""pandas"""
try:
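The `TYPE_CHECKING`-only `Tags` import above lets estimator methods be annotated without importing `sklearn.utils.Tags` at runtime, since that class only exists on scikit-learn >= 1.6. A minimal sketch of the pattern, assuming scikit-learn >= 1.6 when the method is actually called; `_DemoTagsEstimator` is illustrative and not code from this commit:

```python
from typing import TYPE_CHECKING

from sklearn.base import BaseEstimator

if TYPE_CHECKING:
    # resolved only by type checkers, never at runtime
    from sklearn.utils import Tags as _sklearn_Tags


class _DemoTagsEstimator(BaseEstimator):
    def __sklearn_tags__(self) -> "_sklearn_Tags":
        # the quoted annotation keeps this module importable on scikit-learn < 1.6,
        # where sklearn.utils.Tags does not exist; calling the method needs >= 1.6
        tags = super().__sklearn_tags__()
        tags.non_deterministic = False
        return tags
```

As the in-code comment notes, the try/except can be dropped and the import made unconditional once lightgbm's minimum scikit-learn version reaches 1.6.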
