Skip to content

Commit

Permalink
Adapt to scikit-learn 1.6 estimator tag changes (#11021)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Dec 4, 2024
1 parent 23aadda commit 337265a
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,13 @@ credentials.csv
.bloop

# python tests
*.bin
demo/**/*.txt
*.dmatrix
.hypothesis
__MACOSX/
model*.json
/tests/python/models/models/

# R tests
*.htm
Expand Down
2 changes: 2 additions & 0 deletions python-package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ disable = [
"import-error",
"attribute-defined-outside-init",
"import-outside-toplevel",
"too-few-public-methods",
"too-many-ancestors",
"too-many-nested-blocks",
"unsubscriptable-object",
"useless-object-inheritance"
Expand Down
27 changes: 19 additions & 8 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,43 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:

# sklearn
try:
from sklearn import __version__ as _sklearn_version
from sklearn.base import BaseEstimator as XGBModelBase
from sklearn.base import ClassifierMixin as XGBClassifierBase
from sklearn.base import RegressorMixin as XGBRegressorBase
from sklearn.preprocessing import LabelEncoder

try:
from sklearn.model_selection import KFold as XGBKFold
from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
except ImportError:
from sklearn.cross_validation import KFold as XGBKFold
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold

# sklearn.utils Tags types can be imported unconditionally once
# xgboost's minimum scikit-learn version is 1.6 or higher
try:
from sklearn.utils import Tags as _sklearn_Tags
except ImportError:
_sklearn_Tags = object

SKLEARN_INSTALLED = True

except ImportError:
SKLEARN_INSTALLED = False

# used for compatibility without sklearn
XGBModelBase = object
XGBClassifierBase = object
XGBRegressorBase = object
LabelEncoder = object
class XGBModelBase: # type: ignore[no-redef]
"""Dummy class for sklearn.base.BaseEstimator."""

class XGBClassifierBase: # type: ignore[no-redef]
"""Dummy class for sklearn.base.ClassifierMixin."""

class XGBRegressorBase: # type: ignore[no-redef]
"""Dummy class for sklearn.base.RegressorMixin."""

XGBKFold = None
XGBStratifiedKFold = None

_sklearn_Tags = object
_sklearn_version = object


_logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def c_array(
def from_array_interface(interface: dict) -> NumpyOrCupy:
"""Convert array interface to numpy or cupy array"""

class Array: # pylint: disable=too-few-public-methods
class Array:
"""Wrapper type for communicating with numpy and cupy."""

_interface: Optional[dict] = None
Expand Down
1 change: 0 additions & 1 deletion python-package/xgboost/dask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# pylint: disable=too-many-arguments, too-many-locals
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines
# pylint: disable=too-few-public-methods
"""
Dask extensions for distributed training
----------------------------------------
Expand Down
80 changes: 72 additions & 8 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
XGBClassifierBase,
XGBModelBase,
XGBRegressorBase,
_sklearn_Tags,
_sklearn_version,
import_cupy,
)
from .config import config_context
Expand All @@ -54,7 +56,7 @@
from .training import train


class XGBRankerMixIn: # pylint: disable=too-few-public-methods
class XGBRankerMixIn:
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
base classes.
Expand All @@ -79,7 +81,7 @@ def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool:
return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl


class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
class _SklObjWProto(Protocol):
def __call__(
self,
y_true: ArrayLike,
Expand Down Expand Up @@ -805,6 +807,41 @@ def _more_tags(self) -> Dict[str, bool]:
tags["non_deterministic"] = True
return tags

@staticmethod
def _update_sklearn_tags_from_dict(
*,
tags: _sklearn_Tags,
tags_dict: Dict[str, bool],
) -> _sklearn_Tags:
"""Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes.
``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags.
ref: https://github.com/scikit-learn/scikit-learn/pull/29677
This method handles updating that instance based on the values in ``self._more_tags()``.
"""
tags.non_deterministic = tags_dict.get("non_deterministic", False)
tags.no_validation = tags_dict["no_validation"]
tags.input_tags.allow_nan = tags_dict["allow_nan"]
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
# XGBModelBase.__sklearn_tags__() cannot be called unconditionally,
# because that method isn't defined for scikit-learn<1.6
if not hasattr(XGBModelBase, "__sklearn_tags__"):
err_msg = (
"__sklearn_tags__() should not be called when using scikit-learn<1.6. "
f"Detected version: {_sklearn_version}"
)
raise AttributeError(err_msg)

# take whatever tags are provided by BaseEstimator, then modify
# them with XGBoost-specific values
return self._update_sklearn_tags_from_dict(
tags=super().__sklearn_tags__(), # pylint: disable=no-member
tags_dict=self._more_tags(),
)

def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")

Expand Down Expand Up @@ -898,13 +935,27 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
"""Get parameters."""
# Based on: https://stackoverflow.com/questions/59248211
# The basic flow in `get_params` is:
# 0. Return parameters in subclass first, by using inspect.
# 1. Return parameters in `XGBModel` (the base class).
# 0. Return parameters in subclass (self.__class__) first, by using inspect.
# 1. Return parameters in all parent classes (especially `XGBModel`).
# 2. Return whatever in `**kwargs`.
# 3. Merge them.
#
# This needs to accommodate being called recursively in the following
# inheritance graphs (and similar for classification and ranking):
#
# XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator
# XGBRegressor -> XGBModel -> BaseEstimator
# XGBModel -> BaseEstimator
#
params = super().get_params(deep)
cp = copy.copy(self)
cp.__class__ = cp.__class__.__bases__[0]
# If the immediate parent defines get_params(), use that.
if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
cp.__class__ = cp.__class__.__bases__[0]
# Otherwise, skip it and assume the next class will have it.
# This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
else:
cp.__class__ = cp.__class__.__bases__[1]
params.update(cp.__class__.get_params(cp, deep))
# if kwargs is a dict, update params accordingly
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
Expand Down Expand Up @@ -1481,7 +1532,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
Number of boosting rounds.
""",
)
class XGBClassifier(XGBModel, XGBClassifierBase):
class XGBClassifier(XGBClassifierBase, XGBModel):
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
@_deprecate_positional_args
def __init__(
Expand All @@ -1497,6 +1548,12 @@ def _more_tags(self) -> Dict[str, bool]:
tags["multilabel"] = True
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
tags = super().__sklearn_tags__()
tags_dict = self._more_tags()
tags.classifier_tags.multi_label = tags_dict["multilabel"]
return tags

@_deprecate_positional_args
def fit(
self,
Expand Down Expand Up @@ -1769,7 +1826,7 @@ def fit(
"Implementation of the scikit-learn API for XGBoost regression.",
["estimators", "model", "objective"],
)
class XGBRegressor(XGBModel, XGBRegressorBase):
class XGBRegressor(XGBRegressorBase, XGBModel):
# pylint: disable=missing-docstring
@_deprecate_positional_args
def __init__(
Expand All @@ -1783,6 +1840,13 @@ def _more_tags(self) -> Dict[str, bool]:
tags["multioutput_only"] = False
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
tags = super().__sklearn_tags__()
tags_dict = self._more_tags()
tags.target_tags.multi_output = tags_dict["multioutput"]
tags.target_tags.single_output = not tags_dict["multioutput_only"]
return tags


@xgboost_model_doc(
"scikit-learn API for XGBoost random forest regression.",
Expand Down Expand Up @@ -1910,7 +1974,7 @@ def _get_qid(
`qid` can be a special column of input `X` instead of a separated parameter, see
:py:meth:`fit` for more info.""",
)
class XGBRanker(XGBModel, XGBRankerMixIn):
class XGBRanker(XGBRankerMixIn, XGBModel):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@_deprecate_positional_args
def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any):
Expand Down
4 changes: 2 additions & 2 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import base64

# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
# pylint: disable=fixme, protected-access, no-member, invalid-name
# pylint: disable=too-many-lines, too-many-branches
import json
import logging
import os
Expand Down
3 changes: 1 addition & 2 deletions python-package/xgboost/spark/estimator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Xgboost pyspark integration submodule for estimator API."""

# pylint: disable=too-many-ancestors
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=fixme, protected-access, no-member, invalid-name
# pylint: disable=unused-argument, too-many-locals

import warnings
Expand Down
1 change: 0 additions & 1 deletion python-package/xgboost/spark/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import Dict

# pylint: disable=too-few-public-methods
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import Param, Params

Expand Down
2 changes: 1 addition & 1 deletion python-package/xgboost/spark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _get_default_params_from_func(
return filtered_params_dict


class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
class CommunicatorContext(CCtx):
"""Context with PySpark specific task ID."""

def __init__(self, context: BarrierTaskContext, **args: CollArgsVals) -> None:
Expand Down
2 changes: 1 addition & 1 deletion python-package/xgboost/testing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def is_binary(self) -> bool:
return self.max_rel == 1


class PBM: # pylint: disable=too-few-public-methods
class PBM:
"""Simulate click data with position bias model. There are other models available in
`ULTRA <https://github.com/ULTR-Community/ULTRA.git>`_ like the cascading model.
Expand Down
Loading

0 comments on commit 337265a

Please sign in to comment.