Skip to content

Commit

Permalink
MAINT use sklearn_compat for multi-version scikit-learn compatibilities
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Dec 8, 2024
1 parent d4d6c87 commit fed9115
Show file tree
Hide file tree
Showing 2 changed files with 649 additions and 92 deletions.
98 changes: 6 additions & 92 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,91 +12,11 @@
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y

try:
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
except ImportError:
from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
from sklearn.utils.validation import NotFittedError
try:
from sklearn.utils.validation import _check_sample_weight
except ImportError:
from sklearn.utils.validation import check_consistent_length

# dummy function to support older version of scikit-learn
def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
check_consistent_length(sample_weight, X)
return sample_weight

try:
from sklearn.utils.validation import validate_data
except ImportError:
# validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
# It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
def validate_data(
_estimator,
X,
y="no_validation",
accept_sparse: bool = True,
# 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
ensure_all_finite: bool = False,
ensure_min_samples: int = 1,
# trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
**ignored_kwargs,
):
# it's safe to import _num_features unconditionally because:
#
# * it was first added in scikit-learn 0.24.2
# * lightgbm cannot be used with scikit-learn versions older than that
# * this validate_data() re-implementation will not be called in scikit-learn>=1.6
#
from sklearn.utils.validation import _num_features

# _num_features() raises a TypeError on 1-dimensional input. That's a problem
# because scikit-learn's 'check_fit1d' estimator check sets that expectation that
# estimators must raise a ValueError when a 1-dimensional input is passed to fit().
#
# So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
if hasattr(X, "shape") and len(X.shape) == 1:
n_features_in_ = 1
else:
n_features_in_ = _num_features(X)

no_val_y = isinstance(y, str) and y == "no_validation"

# NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
if no_val_y:
X = check_array(
X,
accept_sparse=accept_sparse,
force_all_finite=ensure_all_finite,
ensure_min_samples=ensure_min_samples,
)
else:
X, y = check_X_y(
X,
y,
accept_sparse=accept_sparse,
force_all_finite=ensure_all_finite,
ensure_min_samples=ensure_min_samples,
)

# this only needs to be updated at fit() time
_estimator.n_features_in_ = n_features_in_

# raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
raise ValueError(
f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
f"is expecting {_estimator._n_features} features as input."
)

if no_val_y:
return X
else:
return X, y
from sklearn.utils.validation import assert_all_finite
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
from sklearn.utils.validation import _check_sample_weight
from .sklearn_compat.utils.validation import validate_data

SKLEARN_INSTALLED = True
_LGBMBaseCrossValidator = BaseCrossValidator
Expand Down Expand Up @@ -144,13 +64,7 @@ class _LGBMRegressorBase: # type: ignore

# additional scikit-learn imports only for type hints
if TYPE_CHECKING:
# sklearn.utils.Tags can be imported unconditionally once
# lightgbm's minimum scikit-learn version is 1.6 or higher
try:
from sklearn.utils import Tags as _sklearn_Tags
except ImportError:
_sklearn_Tags = None

from .sklearn_compat.utils import Tags as _sklearn_Tags

"""pandas"""
try:
Expand Down
Loading

0 comments on commit fed9115

Please sign in to comment.