Skip to content

Commit

Permalink
Merge pull request #66 from y0z/feature/sklearn
Browse files Browse the repository at this point in the history
Add sklearn integration
  • Loading branch information
contramundum53 authored Feb 6, 2024
2 parents 6462be5 + 29e74be commit 2734577
Show file tree
Hide file tree
Showing 7 changed files with 1,546 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc
* [Keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#keras) ([example](https://github.com/optuna/optuna-examples/tree/main/keras))
* [MXNet](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#mxnet) ([example](https://github.com/optuna/optuna-examples/tree/main/mxnet))
* [SHAP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#shap)
* [sklearn](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#sklearn) ([example](https://github.com/optuna/optuna-examples/tree/main/sklearn/sklearn_optuna_search_cv_simple.py))
* [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py))
* [TensorBoard](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorboard) ([example](https://github.com/optuna/optuna-examples/tree/main/tensorboard/tensorboard_simple.py))
* [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py))
Expand Down
49 changes: 29 additions & 20 deletions docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ The former is provided for backward compatibility.

For most of the ML frameworks supported by Optuna, the corresponding Optuna integration class serves only to implement a callback object and functions, compliant with the framework's specific callback API, to be called with each intermediate step in the model training. The functionality implemented in these callbacks across the different ML frameworks includes:

(1) Reporting intermediate model scores back to the Optuna trial using :func:`optuna.trial.Trial.report`,
(2) According to the results of :func:`optuna.trial.Trial.should_prune`, pruning the current model by raising :func:`optuna.TrialPruned`, and
(3) Reporting intermediate Optuna data such as the current trial number back to the framework, as done in :class:`~optuna.integration.MLflowCallback`.
(1) Reporting intermediate model scores back to the Optuna trial using `optuna.trial.Trial.report <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.report>`_,
(2) According to the results of `optuna.trial.Trial.should_prune <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.should_prune>`_, pruning the current model by raising `optuna.TrialPruned <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.TrialPruned.html#optuna.TrialPruned>`_, and
(3) Reporting intermediate Optuna data such as the current trial number back to the framework, as done in :class:`~optuna_integration.MLflowCallback`.

For scikit-learn, an integrated :class:`~optuna.integration.OptunaSearchCV` estimator is available that combines scikit-learn BaseEstimator functionality with access to a class-level ``Study`` object.
For scikit-learn, an integrated :class:`~optuna_integration.OptunaSearchCV` estimator is available that combines scikit-learn BaseEstimator functionality with access to a class-level ``Study`` object.

AllenNLP
--------
Expand All @@ -23,9 +23,9 @@ AllenNLP
:toctree: generated/
:nosignatures:

optuna.integration.AllenNLPExecutor
optuna.integration.allennlp.dump_best_config
optuna.integration.AllenNLPPruningCallback
optuna_integration.AllenNLPExecutor
optuna_integration.allennlp.dump_best_config
optuna_integration.AllenNLPPruningCallback

Catalyst
--------
Expand All @@ -34,7 +34,7 @@ Catalyst
:toctree: generated/
:nosignatures:

optuna.integration.CatalystPruningCallback
optuna_integration.CatalystPruningCallback

CatBoost
--------
Expand All @@ -43,7 +43,7 @@ CatBoost
:toctree: generated/
:nosignatures:

optuna.integration.CatBoostPruningCallback
optuna_integration.CatBoostPruningCallback

Chainer
-------
Expand All @@ -52,8 +52,8 @@ Chainer
:toctree: generated/
:nosignatures:

optuna.integration.ChainerPruningExtension
optuna.integration.ChainerMNStudy
optuna_integration.ChainerPruningExtension
optuna_integration.ChainerMNStudy

Dask
----
Expand All @@ -71,9 +71,9 @@ fast.ai
:toctree: generated/
:nosignatures:

optuna.integration.FastAIV1PruningCallback
optuna.integration.FastAIV2PruningCallback
optuna.integration.FastAIPruningCallback
optuna_integration.FastAIV1PruningCallback
optuna_integration.FastAIV2PruningCallback
optuna_integration.FastAIPruningCallback

Keras
-----
Expand All @@ -82,7 +82,7 @@ Keras
:toctree: generated/
:nosignatures:

optuna.integration.KerasPruningCallback
optuna_integration.KerasPruningCallback

MXNet
-----
Expand All @@ -91,7 +91,7 @@ MXNet
:toctree: generated/
:nosignatures:

optuna.integration.MXNetPruningCallback
optuna_integration.MXNetPruningCallback

SHAP
----
Expand All @@ -100,7 +100,16 @@ SHAP
:toctree: generated/
:nosignatures:

optuna.integration.ShapleyImportanceEvaluator
optuna_integration.ShapleyImportanceEvaluator

sklearn
-------

.. autosummary::
:toctree: generated/
:nosignatures:

optuna_integration.OptunaSearchCV

skorch
------
Expand All @@ -109,7 +118,7 @@ skorch
:toctree: generated/
:nosignatures:

optuna.integration.SkorchPruningCallback
optuna_integration.SkorchPruningCallback

TensorBoard
-----------
Expand All @@ -118,7 +127,7 @@ TensorBoard
:toctree: generated/
:nosignatures:

optuna.integration.TensorBoardCallback
optuna_integration.TensorBoardCallback

TensorFlow
----------
Expand All @@ -127,4 +136,4 @@ TensorFlow
:toctree: generated/
:nosignatures:

optuna.integration.TFKerasPruningCallback
optuna_integration.TFKerasPruningCallback
110 changes: 110 additions & 0 deletions optuna_integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import os
import sys
from types import ModuleType
from typing import Any
from typing import TYPE_CHECKING


_import_structure = {
"allennlp": ["AllenNLPExecutor", "AllenNLPPruningCallback"],
"catalyst": ["CatalystPruningCallback"],
"catboost": ["CatBoostPruningCallback"],
"chainer": ["ChainerPruningExtension"],
"chainermn": ["ChainerMNStudy"],
"fastaiv1": ["FastAIV1PruningCallback"],
"fastaiv2": ["FastAIV2PruningCallback", "FastAIPruningCallback"],
"keras": ["KerasPruningCallback"],
"mxnet": ["MXNetPruningCallback"],
"shap": ["ShapleyImportanceEvaluator"],
"sklearn": ["OptunaSearchCV"],
"skorch": ["SkorchPruningCallback"],
"tensorboard": ["TensorBoardCallback"],
"tensorflow": ["TensorFlowPruningHook"],
"tfkeras": ["TFKerasPruningCallback"],
}


if TYPE_CHECKING:
from optuna_integration.allennlp import AllenNLPExecutor
from optuna_integration.allennlp import AllenNLPPruningCallback
from optuna_integration.catalyst import CatalystPruningCallback
from optuna_integration.catboost import CatBoostPruningCallback
from optuna_integration.chainer import ChainerPruningExtension
from optuna_integration.chainermn import ChainerMNStudy
from optuna_integration.fastaiv1 import FastAIV1PruningCallback
from optuna_integration.fastaiv2 import FastAIPruningCallback
from optuna_integration.fastaiv2 import FastAIV2PruningCallback
from optuna_integration.keras import KerasPruningCallback
from optuna_integration.mxnet import MXNetPruningCallback
from optuna_integration.shap import ShapleyImportanceEvaluator
from optuna_integration.sklearn import OptunaSearchCV
from optuna_integration.skorch import SkorchPruningCallback
from optuna_integration.tensorboard import TensorBoardCallback
from optuna_integration.tensorflow import TensorFlowPruningHook
from optuna_integration.tfkeras import TFKerasPruningCallback
else:

class _IntegrationModule(ModuleType):
"""Module class that implements `optuna_integration` package.
This class applies lazy import under `optuna_integration`, where submodules are imported
when they are actually accessed. Otherwise, `import optuna` becomes much slower because it
imports all submodules and their dependencies (e.g., chainer, keras, lightgbm) all at once.
"""

__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]

_modules = set(_import_structure.keys())
_class_to_module = {}
for key, values in _import_structure.items():
for value in values:
_class_to_module[value] = key

def __getattr__(self, name: str) -> Any:
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError("module {} has no attribute {}".format(self.__name__, name))

setattr(self, name, value)
return value

def _get_module(self, module_name: str) -> ModuleType:
import importlib

try:
return importlib.import_module("." + module_name, self.__name__)
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Optuna's integration modules for third-party libraries have started "
"migrating from Optuna itself to a package called `optuna-integration`. "
"The module you are trying to use has already been migrated to "
"`optuna-integration`. Please install the package by running "
"`pip install optuna-integration`."
)

sys.modules[__name__] = _IntegrationModule(__name__)

__all__ = [
"AllenNLPExecutor",
"AllenNLPPruningCallback",
"CatalystPruningCallback",
"CatBoostPruningCallback",
"ChainerMNStudy",
"ChainerPruningExtension",
"FastAIPruningCallback",
"FastAIV1PruningCallback",
"FastAIV2PruningCallback",
"KerasPruningCallback",
"MXNetPruningCallback",
"OptunaSearchCV",
"ShapleyImportanceEvaluator",
"SkorchPruningCallback",
"TensorBoardCallback",
"TensorFlowPruningHook",
"TFKerasPruningCallback",
]
3 changes: 1 addition & 2 deletions optuna_integration/shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@


with try_import() as _imports:
from sklearn.ensemble import RandomForestRegressor

from shap import TreeExplainer
from sklearn.ensemble import RandomForestRegressor


@experimental_class("3.0.0")
Expand Down
Loading

0 comments on commit 2734577

Please sign in to comment.