Skip to content

Commit

Permalink
FEAT add scikit-learn wrappers (#20599)
Browse files Browse the repository at this point in the history
* FEAT add scikit-learn wrappers

* import cleanup

* run black

* linters

* lint

* add scikit-learn to requirements-common

* generate public api

* fix tests for sklearn 1.5

* check fixes

* skip numpy tests

* xfail instead of skip

* apply review comments

* change names to SKL* and add transformer example

* fix API and imports

* fix for new sklearn

* sklearn1.6 test

* review comments and remove random_state

* add another skipped test

* rename file

* change imports

* unindent

* docstrings
  • Loading branch information
adrinjalali authored Dec 12, 2024
1 parent 8465c3d commit 32a642d
Show file tree
Hide file tree
Showing 11 changed files with 806 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from keras.api import utils
from keras.api import version
from keras.api import visualization
from keras.api import wrappers

# END DO NOT EDIT.

Expand Down
1 change: 1 addition & 0 deletions keras/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from keras.api import tree
from keras.api import utils
from keras.api import visualization
from keras.api import wrappers
from keras.src.backend import Variable
from keras.src.backend import device
from keras.src.backend import name_scope
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from keras.api import tree
from keras.api import utils
from keras.api import visualization
from keras.api import wrappers
from keras.api._tf_keras.keras import backend
from keras.api._tf_keras.keras import layers
from keras.api._tf_keras.keras import losses
Expand Down
9 changes: 9 additions & 0 deletions keras/api/_tf_keras/keras/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""DO NOT EDIT.
This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer
9 changes: 9 additions & 0 deletions keras/api/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""DO NOT EDIT.
This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer
5 changes: 5 additions & 0 deletions keras/src/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer

__all__ = ["SKLearnClassifier", "SKLearnRegressor", "SKLearnTransformer"]
119 changes: 119 additions & 0 deletions keras/src/wrappers/fixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import sklearn
from packaging.version import parse as parse_version
from sklearn import get_config

# Parsed scikit-learn version, using only the base version so pre-releases
# (e.g. "1.6.0rc1") compare like their final release.
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version < parse_version("1.6"):
    # scikit-learn < 1.6: `parametrize_with_checks` has no
    # `expected_failed_checks` argument, so emulate it by patching each
    # estimator class's `_more_tags` to advertise the expected failures
    # through the legacy `_xfail_checks` tag.

    def patched_more_tags(estimator, expected_failed_checks):
        """Patch `estimator.__class__._more_tags` to xfail the given checks.

        `expected_failed_checks` is a dict of {check_name: reason} entries
        that the old-style check runner reads from the `_xfail_checks` tag.
        """
        import copy

        from sklearn.utils._tags import _safe_tags

        # Snapshot the tags before patching so the replacement `_more_tags`
        # returns the original tags plus the xfail entries.
        original_tags = copy.deepcopy(_safe_tags(estimator))

        def patched_more_tags(self):
            original_tags.update({"_xfail_checks": expected_failed_checks})
            return original_tags

        # NOTE: this mutates the class, not the instance, so it affects all
        # instances of the estimator's class.
        estimator.__class__._more_tags = patched_more_tags
        return estimator

    def parametrize_with_checks(
        estimators,
        *,
        legacy: bool = True,
        expected_failed_checks=None,
    ):
        """Backport of scikit-learn 1.6's `parametrize_with_checks` signature.

        `expected_failed_checks` is a callable mapping an estimator to a
        dict of {check_name: reason} entries that should xfail.
        """
        # legacy is not supported and ignored
        from sklearn.utils.estimator_checks import parametrize_with_checks  # noqa: F401, I001

        estimators = [
            patched_more_tags(estimator, expected_failed_checks(estimator))
            for estimator in estimators
        ]

        return parametrize_with_checks(estimators)
else:
    # scikit-learn >= 1.6 supports `expected_failed_checks` natively.
    from sklearn.utils.estimator_checks import parametrize_with_checks  # noqa: F401, I001


def _validate_data(estimator, *args, **kwargs):
"""Validate the input data.
wrapper for sklearn.utils.validation.validate_data or
BaseEstimator._validate_data depending on the scikit-learn version.
TODO: remove when minimum scikit-learn version is 1.6
"""
try:
# scikit-learn >= 1.6
from sklearn.utils.validation import validate_data

return validate_data(estimator, *args, **kwargs)
except ImportError:
return estimator._validate_data(*args, **kwargs)
except:
raise


def type_of_target(y, input_name="", *, raise_unknown=False):
    """Determine the type of data indicated by the target `y`.

    Thin wrapper around `sklearn.utils.multiclass.type_of_target` that adds
    the `raise_unknown` argument, which is only available natively in
    scikit-learn >= 1.6.

    Args:
        y: array-like target values.
        input_name: name of the input, used in the error message.
        raise_unknown: if True, raise a ValueError when the target type
            cannot be determined (i.e. is "unknown").

    Returns:
        One of scikit-learn's target-type strings, e.g. "binary",
        "multiclass", "continuous", or "unknown".
    """
    # fix for raise_unknown which is introduced in scikit-learn 1.6
    from sklearn.utils.multiclass import type_of_target

    target_type = type_of_target(y, input_name=input_name)
    if raise_unknown and target_type == "unknown":
        # Renamed from `input` to avoid shadowing the builtin; message
        # unchanged.
        target_name = input_name if input_name else "data"
        raise ValueError(f"Unknown label type for {target_name}: {y!r}")
    return target_type


def _routing_enabled():
    """Return whether metadata routing is enabled.

    Returns:
        enabled : bool
            Whether metadata routing is enabled. If the config is not set,
            it defaults to False.

    TODO: remove when the config key is no longer available in scikit-learn
    """
    config = get_config()
    return config.get("enable_metadata_routing", False)


def _raise_for_params(params, owner, method):
    """Raise an error if metadata routing is not enabled and params are passed.

    Parameters:
        params : dict
            The metadata passed to a method.
        owner : object
            The object to which the method belongs.
        method : str
            The name of the method, e.g. "fit".

    Raises:
        ValueError
            If metadata routing is not enabled and params are passed.
    """
    # Nothing to do when routing is on or no extra metadata was given.
    if _routing_enabled() or not params:
        return
    caller = owner.__class__.__name__
    if method:
        caller = f"{caller}.{method}"
    raise ValueError(
        f"Passing extra keyword arguments to {caller} is only supported if"
        " enable_metadata_routing=True, which you can set using"
        " `sklearn.set_config`. See the User Guide"
        " <https://scikit-learn.org/stable/metadata_routing.html> for more"
        f" details. Extra parameters passed are: {set(params)}"
    )
119 changes: 119 additions & 0 deletions keras/src/wrappers/sklearn_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Tests using Scikit-Learn's bundled estimator_checks."""

from contextlib import contextmanager

import pytest

import keras
from keras.src.backend import floatx
from keras.src.backend import set_floatx
from keras.src.layers import Dense
from keras.src.layers import Input
from keras.src.models import Model
from keras.src.wrappers import SKLearnClassifier
from keras.src.wrappers import SKLearnRegressor
from keras.src.wrappers import SKLearnTransformer
from keras.src.wrappers.fixes import parametrize_with_checks


def dynamic_model(X, y, loss, layers=(10,)):
    """Creates a basic MLP classifier dynamically choosing binary/multiclass
    classification loss and output activations.

    Args:
        X: 2D input array; only its number of columns (features) is used.
        y: target array; a 2D `y` sets the output width, otherwise 1.
        loss: loss name/object forwarded to `Model.compile`.
        layers: iterable of hidden-layer sizes. A tuple default replaces the
            original mutable default `[10]` (same value, same behavior).

    Returns:
        A compiled keras `Model`.
    """
    n_features_in = X.shape[1]
    inp = Input(shape=(n_features_in,))

    hidden = inp
    for layer_size in layers:
        hidden = Dense(layer_size, activation="relu")(hidden)

    # 2D targets (e.g. one-hot labels) determine the output width; 1D
    # targets get a single output unit.
    n_outputs = y.shape[1] if len(y.shape) > 1 else 1
    out = [Dense(n_outputs, activation="softmax")(hidden)]
    model = Model(inp, out)
    model.compile(loss=loss, optimizer="rmsprop")

    return model


@contextmanager
def use_floatx(x: str):
    """Context manager that temporarily sets the keras backend precision,
    restoring the previous value on exit.
    """
    previous = floatx()
    set_floatx(x)
    try:
        yield
    finally:
        # Always restore, even if the managed block raised.
        set_floatx(previous)


# Map of {wrapper class name: {sklearn check name: reason}} for estimator
# checks that are expected to fail. Consumed below via the
# `expected_failed_checks` callable passed to `parametrize_with_checks`.
EXPECTED_FAILED_CHECKS = {
    "SKLearnClassifier": {
        "check_classifiers_regression_target": "not an issue in sklearn>=1.6",
        "check_parameters_default_constructible": (
            "not an issue in sklearn>=1.6"
        ),
        "check_classifiers_one_label_sample_weights": (
            "0 sample weight is not ignored"
        ),
        "check_classifiers_classes": (
            "with small test cases the estimator returns not all classes "
            "sometimes"
        ),
        "check_classifier_data_not_an_array": (
            "This test assumes reproducibility in fit."
        ),
        "check_supervised_y_2d": "This test assumes reproducibility in fit.",
        "check_fit_idempotent": "This test assumes reproducibility in fit.",
    },
    "SKLearnRegressor": {
        "check_parameters_default_constructible": (
            "not an issue in sklearn>=1.6"
        ),
    },
    "SKLearnTransformer": {
        "check_parameters_default_constructible": (
            "not an issue in sklearn>=1.6"
        ),
    },
}


@parametrize_with_checks(
    estimators=[
        # One instance per wrapper class; each builds its keras model from
        # the data at fit time via `dynamic_model`.
        SKLearnClassifier(
            model=dynamic_model,
            model_kwargs={
                "loss": "categorical_crossentropy",
                "layers": [20, 20, 20],
            },
            fit_kwargs={"epochs": 5},
        ),
        SKLearnRegressor(
            model=dynamic_model,
            model_kwargs={"loss": "mse"},
        ),
        SKLearnTransformer(
            model=dynamic_model,
            model_kwargs={"loss": "mse"},
        ),
    ],
    # Look up the per-class xfail dict defined above.
    expected_failed_checks=lambda estimator: EXPECTED_FAILED_CHECKS[
        type(estimator).__name__
    ],
)
def test_sklearn_estimator_checks(estimator, check):
    """Checks that can be passed with sklearn's default tolerances
    and in a single epoch.
    """
    try:
        check(estimator)
    except Exception as exc:
        # On the numpy backend, a NotImplementedError (possibly wrapped so
        # only its name appears in another exception's message) is an
        # expected failure rather than a bug.
        if keras.config.backend() == "numpy" and (
            isinstance(exc, NotImplementedError)
            or "NotImplementedError" in str(exc)
        ):
            pytest.xfail("Backend not implemented")
        else:
            raise
Loading

0 comments on commit 32a642d

Please sign in to comment.