From f42d94f39f55a1dd24b5722139802593afe73995 Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Wed, 31 Jan 2024 12:08:01 +0900 Subject: [PATCH 01/11] Not use black 24.* --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0821b2a6..d31d46e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ test = [ "fakeredis[lua]", ] checking = [ - "black", + "black<24.0.0", "blackdoc", "hacking", "isort", From d4dc81b70ac342400de640d7660aa7d6e8957caa Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 1 Feb 2024 20:35:47 +0900 Subject: [PATCH 02/11] Fetched from https://github.com/optuna/optuna/commit/2b1572f447d95d401db62e9305b3222c705ef986. --- optuna_integration/sklearn.py | 948 ++++++++++++++++++++++++++++++++++ 1 file changed, 948 insertions(+) create mode 100644 optuna_integration/sklearn.py diff --git a/optuna_integration/sklearn.py b/optuna_integration/sklearn.py new file mode 100644 index 00000000..5135ffb2 --- /dev/null +++ b/optuna_integration/sklearn.py @@ -0,0 +1,948 @@ +from __future__ import annotations + +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Mapping +from logging import DEBUG +from logging import INFO +from logging import WARNING +from numbers import Integral +from numbers import Number +from time import time +from typing import Any +from typing import List +from typing import Union + +import numpy as np + +from optuna import distributions +from optuna import logging +from optuna import samplers +from optuna import study as study_module +from optuna import TrialPruned +from optuna._experimental import experimental_class +from optuna._imports import try_import +from optuna.distributions import _convert_old_distribution_to_new_distribution +from optuna.study import StudyDirection +from optuna.terminator import report_cross_validation_scores +from optuna.trial import FrozenTrial +from optuna.trial import Trial + + +with 
try_import() as _imports: + import pandas as pd + import scipy as sp + from scipy.sparse import spmatrix + import sklearn + from sklearn.base import BaseEstimator + from sklearn.base import clone + from sklearn.base import is_classifier + from sklearn.metrics import check_scoring + from sklearn.model_selection import BaseCrossValidator + from sklearn.model_selection import check_cv + from sklearn.model_selection import cross_validate + from sklearn.utils import _safe_indexing as sklearn_safe_indexing + from sklearn.utils import check_random_state + from sklearn.utils.metaestimators import _safe_split + from sklearn.utils.validation import check_is_fitted + +if not _imports.is_successful(): + BaseEstimator = object # NOQA + +ArrayLikeType = Union[List, np.ndarray, "pd.Series", "spmatrix"] +OneDimArrayLikeType = Union[List[float], np.ndarray, "pd.Series"] +TwoDimArrayLikeType = Union[List[List[float]], np.ndarray, "pd.DataFrame", "spmatrix"] +IterableType = Union[List, "pd.DataFrame", np.ndarray, "pd.Series", "spmatrix", None] +IndexableType = Union[Iterable, None] + +_logger = logging.get_logger(__name__) + + +def _check_fit_params( + X: TwoDimArrayLikeType, fit_params: dict, indices: OneDimArrayLikeType +) -> dict: + fit_params_validated = {} + for key, value in fit_params.items(): + # NOTE Original implementation: + # https://github.com/scikit-learn/scikit-learn/blob/ \ + # 2467e1b84aeb493a22533fa15ff92e0d7c05ed1c/sklearn/utils/validation.py#L1324-L1328 + # Scikit-learn does not accept non-iterable inputs. + # This line is for keeping backward compatibility. 
+ # (See: https://github.com/scikit-learn/scikit-learn/issues/15805) + if not _is_arraylike(value) or _num_samples(value) != _num_samples(X): + fit_params_validated[key] = value + else: + fit_params_validated[key] = _make_indexable(value) + fit_params_validated[key] = _safe_indexing(fit_params_validated[key], indices) + return fit_params_validated + + +# NOTE Original implementation: +# https://github.com/scikit-learn/scikit-learn/blob/ \ +# 8caa93889f85254fc3ca84caa0a24a1640eebdd1/sklearn/utils/validation.py#L131-L135 +def _is_arraylike(x: Any) -> bool: + return hasattr(x, "__len__") or hasattr(x, "shape") or hasattr(x, "__array__") + + +# NOTE Original implementation: +# https://github.com/scikit-learn/scikit-learn/blob/ \ +# 8caa93889f85254fc3ca84caa0a24a1640eebdd1/sklearn/utils/validation.py#L217-L234 +def _make_indexable(iterable: IterableType) -> IndexableType: + tocsr_func = getattr(iterable, "tocsr", None) + if tocsr_func is not None and sp.sparse.issparse(iterable): + return tocsr_func(iterable) + elif hasattr(iterable, "__getitem__") or hasattr(iterable, "iloc"): + return iterable + elif iterable is None: + return iterable + return np.array(iterable) + + +def _num_samples(x: ArrayLikeType) -> int: + # NOTE For dask dataframes + # https://github.com/scikit-learn/scikit-learn/blob/ \ + # 8caa93889f85254fc3ca84caa0a24a1640eebdd1/sklearn/utils/validation.py#L155-L158 + x_shape = getattr(x, "shape", None) + if x_shape is not None: + if isinstance(x_shape[0], Integral): + return int(x_shape[0]) + + try: + return len(x) + except TypeError: + raise TypeError("Expected sequence or array-like, got %s." % type(x)) from None + + +def _safe_indexing( + X: OneDimArrayLikeType | TwoDimArrayLikeType, indices: OneDimArrayLikeType +) -> OneDimArrayLikeType | TwoDimArrayLikeType: + if X is None: + return X + + return sklearn_safe_indexing(X, indices) + + +class _Objective: + """Callable that implements objective function. 
+ + Args: + estimator: + Object to use to fit the data. This is assumed to implement the + scikit-learn estimator interface. Either this needs to provide + ``score``, or ``scoring`` must be passed. + + param_distributions: + Dictionary where keys are parameters and values are distributions. + Distributions are assumed to implement the optuna distribution + interface. + + X: + Training data. + + y: + Target variable. + + cv: + Cross-validation strategy. + + enable_pruning: + If :obj:`True`, pruning is performed in the case where the + underlying estimator supports ``partial_fit``. + + error_score: + Value to assign to the score if an error occurs in fitting. If + 'raise', the error is raised. If numeric, + ``sklearn.exceptions.FitFailedWarning`` is raised. This does not + affect the refit step, which will always raise the error. + + fit_params: + Parameters passed to ``fit`` on the estimator. + + groups: + Group labels for the samples used while splitting the dataset into + train/validation set. + + max_iter: + Maximum number of epochs. This is only used if the underlying + estimator supports ``partial_fit``. + + return_train_score: + If :obj:`True`, training scores will be included. Computing + training scores is used to get insights on how different + hyperparameter settings impact the overfitting/underfitting + trade-off. However computing training scores can be + computationally expensive and is not strictly required to select + the hyperparameters that yield the best generalization + performance. + + scoring: + Scorer function. 
+ """ + + def __init__( + self, + estimator: "sklearn.base.BaseEstimator", + param_distributions: Mapping[str, distributions.BaseDistribution], + X: TwoDimArrayLikeType, + y: OneDimArrayLikeType | TwoDimArrayLikeType | None, + cv: "BaseCrossValidator", + enable_pruning: bool, + error_score: Number | float | str, + fit_params: dict[str, Any], + groups: OneDimArrayLikeType | None, + max_iter: int, + return_train_score: bool, + scoring: Callable[..., Number], + ) -> None: + self.cv = cv + self.enable_pruning = enable_pruning + self.error_score = error_score + self.estimator = estimator + self.fit_params = fit_params + self.groups = groups + self.max_iter = max_iter + self.param_distributions = param_distributions + self.return_train_score = return_train_score + self.scoring = scoring + self.X = X + self.y = y + + def __call__(self, trial: Trial) -> float: + estimator = clone(self.estimator) + params = self._get_params(trial) + + estimator.set_params(**params) + + if self.enable_pruning: + scores = self._cross_validate_with_pruning(trial, estimator) + else: + try: + scores = cross_validate( + estimator, + self.X, + self.y, + cv=self.cv, + error_score=self.error_score, + fit_params=self.fit_params, + groups=self.groups, + return_train_score=self.return_train_score, + scoring=self.scoring, + ) + except ValueError: + n_splits = self.cv.get_n_splits(self.X, self.y, self.groups) + fit_time = np.array([np.nan] * n_splits) + score_time = np.array([np.nan] * n_splits) + test_score = np.array( + [self.error_score if self.error_score is not None else np.nan] * n_splits + ) + + scores = { + "fit_time": fit_time, + "score_time": score_time, + "test_score": test_score, + } + + self._store_scores(trial, scores) + + test_scores = scores["test_score"] + scores_list = test_scores if isinstance(test_scores, list) else test_scores.tolist() + report_cross_validation_scores(trial, scores_list) + + return trial.user_attrs["mean_test_score"] + + def _cross_validate_with_pruning( + self, 
trial: Trial, estimator: "sklearn.base.BaseEstimator" + ) -> Mapping[str, OneDimArrayLikeType]: + if is_classifier(estimator): + partial_fit_params = self.fit_params.copy() + y = self.y.values if isinstance(self.y, pd.Series) else self.y + classes = np.unique(y) + + partial_fit_params.setdefault("classes", classes) + + else: + partial_fit_params = self.fit_params + + n_splits = self.cv.get_n_splits(self.X, self.y, groups=self.groups) + estimators = [clone(estimator) for _ in range(n_splits)] + scores = { + "fit_time": np.zeros(n_splits), + "score_time": np.zeros(n_splits), + "test_score": np.empty(n_splits), + } + + if self.return_train_score: + scores["train_score"] = np.empty(n_splits) + + for step in range(self.max_iter): + for i, (train, test) in enumerate(self.cv.split(self.X, self.y, groups=self.groups)): + out = self._partial_fit_and_score(estimators[i], train, test, partial_fit_params) + + if self.return_train_score: + scores["train_score"][i] = out.pop(0) + + scores["test_score"][i] = out[0] + scores["fit_time"][i] += out[1] + scores["score_time"][i] += out[2] + + intermediate_value = np.nanmean(scores["test_score"]) + + trial.report(intermediate_value, step=step) + + if trial.should_prune(): + self._store_scores(trial, scores) + + raise TrialPruned("trial was pruned at iteration {}.".format(step)) + + return scores + + def _get_params(self, trial: Trial) -> dict[str, Any]: + return { + name: trial._suggest(name, distribution) + for name, distribution in self.param_distributions.items() + } + + def _partial_fit_and_score( + self, + estimator: "sklearn.base.BaseEstimator", + train: list[int], + test: list[int], + partial_fit_params: dict[str, Any], + ) -> list[Number]: + X_train, y_train = _safe_split(estimator, self.X, self.y, train) + X_test, y_test = _safe_split(estimator, self.X, self.y, test, train_indices=train) + + start_time = time() + + try: + estimator.partial_fit(X_train, y_train, **partial_fit_params) + + except Exception as e: + if 
self.error_score == "raise": + raise e + + elif isinstance(self.error_score, Number): + fit_time = time() - start_time + test_score = self.error_score + score_time = 0.0 + + if self.return_train_score: + train_score = self.error_score + + else: + raise ValueError("error_score must be 'raise' or numeric.") from e + + else: + fit_time = time() - start_time + test_score = self.scoring(estimator, X_test, y_test) + score_time = time() - fit_time - start_time + + if self.return_train_score: + train_score = self.scoring(estimator, X_train, y_train) + + # Required for type checking but is never expected to fail. + assert isinstance(fit_time, Number) + assert isinstance(score_time, Number) + + ret = [test_score, fit_time, score_time] + + if self.return_train_score: + ret.insert(0, train_score) + + return ret + + def _store_scores(self, trial: Trial, scores: Mapping[str, OneDimArrayLikeType]) -> None: + for name, array in scores.items(): + if name in ["test_score", "train_score"]: + for i, score in enumerate(array): + trial.set_user_attr("split{}_{}".format(i, name), score) + + trial.set_user_attr("mean_{}".format(name), np.nanmean(array)) + trial.set_user_attr("std_{}".format(name), np.nanstd(array)) + + +@experimental_class("0.17.0") +class OptunaSearchCV(BaseEstimator): + """Hyperparameter search with cross-validation. + + Args: + estimator: + Object to use to fit the data. This is assumed to implement the + scikit-learn estimator interface. Either this needs to provide + ``score``, or ``scoring`` must be passed. + + param_distributions: + Dictionary where keys are parameters and values are distributions. + Distributions are assumed to implement the optuna distribution + interface. + + cv: + Cross-validation strategy. Possible inputs for cv are: + + - :obj:`None`, to use the default 5-fold cross validation, + - integer to specify the number of folds in a CV splitter, + - `CV splitter `_, + - an iterable yielding (train, validation) splits as arrays of indices. 
+ + For integer, if ``estimator`` is a classifier and ``y`` is + either binary or multiclass, + ``sklearn.model_selection.StratifiedKFold`` is used. Otherwise, + ``sklearn.model_selection.KFold`` is used. + + enable_pruning: + If :obj:`True`, pruning is performed in the case where the + underlying estimator supports ``partial_fit``. + + error_score: + Value to assign to the score if an error occurs in fitting. If + 'raise', the error is raised. If numeric, + ``sklearn.exceptions.FitFailedWarning`` is raised. This does not + affect the refit step, which will always raise the error. + + max_iter: + Maximum number of epochs. This is only used if the underlying + estimator supports ``partial_fit``. + + n_jobs: + Number of :obj:`threading` based parallel jobs. :obj:`None` means ``1``. + ``-1`` means the number of jobs is set to the CPU count. + + .. note:: + ``n_jobs`` allows parallelization using :obj:`threading` and may suffer from + `Python's GIL <https://wiki.python.org/moin/GlobalInterpreterLock>`_. + It is recommended to use :ref:`process-based parallelization` + if fitting ``estimator`` is CPU bound. + + n_trials: + Number of trials. If :obj:`None`, there is no limitation on the + number of trials. If ``timeout`` is also set to :obj:`None`, + the study continues to create trials until it receives a + termination signal such as Ctrl+C or SIGTERM. This trades off + runtime vs quality of the solution. + + random_state: + Seed of the pseudo random number generator. If int, this is the + seed used by the random number generator. If + ``numpy.random.RandomState`` object, this is the random number + generator. If :obj:`None`, the global random state from + ``numpy.random`` is used. + + refit: + If :obj:`True`, refit the estimator with the best found + hyperparameters. The refitted estimator is made available at the + ``best_estimator_`` attribute and permits using ``predict`` + directly. + + return_train_score: + If :obj:`True`, training scores will be included. 
Computing + training scores is used to get insights on how different + hyperparameter settings impact the overfitting/underfitting + trade-off. However computing training scores can be + computationally expensive and is not strictly required to select + the hyperparameters that yield the best generalization + performance. + + scoring: + String or callable to evaluate the predictions on the validation data. + If :obj:`None`, ``score`` on the estimator is used. + + study: + Study corresponds to the optimization task. If :obj:`None`, a new + study is created. + + subsample: + Proportion of samples that are used during hyperparameter search. + + - If int, then draw ``subsample`` samples. + - If float, then draw ``subsample`` * ``X.shape[0]`` samples. + + timeout: + Time limit in seconds for the search of appropriate models. If + :obj:`None`, the study is executed without time limitation. If + ``n_trials`` is also set to :obj:`None`, the study continues to + create trials until it receives a termination signal such as + Ctrl+C or SIGTERM. This trades off runtime vs quality of the + solution. + + verbose: + Verbosity level. The higher, the more messages. + + callbacks: + List of callback functions that are invoked at the end of each trial. Each function + must accept two parameters with the following types in this order: + :class:`~optuna.study.Study` and :class:`~optuna.trial.FrozenTrial`. + + .. seealso:: + + See the tutorial of :ref:`optuna_callback` for how to use and implement + callback functions. + + Attributes: + best_estimator_: + Estimator that was chosen by the search. This is present only if + ``refit`` is set to :obj:`True`. + + n_splits_: + Number of cross-validation splits. + + refit_time_: + Time for refitting the best estimator. This is present only if + ``refit`` is set to :obj:`True`. + + sample_indices_: + Indices of samples that are used during hyperparameter search. + + scorer_: + Scorer function. + + study_: + Actual study. + + Examples: + + .. 
testcode:: + + import optuna + from sklearn.datasets import load_iris + from sklearn.svm import SVC + + clf = SVC(gamma="auto") + param_distributions = { + "C": optuna.distributions.FloatDistribution(1e-10, 1e10, log=True) + } + optuna_search = optuna.integration.OptunaSearchCV(clf, param_distributions) + X, y = load_iris(return_X_y=True) + optuna_search.fit(X, y) + y_pred = optuna_search.predict(X) + + .. note:: + By following the scikit-learn convention for scorers, the direction of optimization is + ``maximize``. See https://scikit-learn.org/stable/modules/model_evaluation.html. + For the minimization problem, please multiply ``-1``. + """ + + _required_parameters = ["estimator", "param_distributions"] + + @property + def _estimator_type(self) -> str: + return self.estimator._estimator_type + + @property + def best_index_(self) -> int: + """Trial number which corresponds to the best candidate parameter setting. + + Returned value is equivalent to ``optuna_search.best_trial_.number``. 
+ """ + + return self.best_trial_.number + + @property + def best_params_(self) -> dict[str, Any]: + """Parameters of the best trial in the :class:`~optuna.study.Study`.""" + + self._check_is_fitted() + + return self.study_.best_params + + @property + def best_score_(self) -> float: + """Mean cross-validated score of the best estimator.""" + + self._check_is_fitted() + + return self.study_.best_value + + @property + def best_trial_(self) -> FrozenTrial: + """Best trial in the :class:`~optuna.study.Study`.""" + + self._check_is_fitted() + + return self.study_.best_trial + + @property + def classes_(self) -> OneDimArrayLikeType: + """Class labels.""" + + self._check_is_fitted() + + return self.best_estimator_.classes_ + + @property + def cv_results_(self) -> dict[str, Any]: + """A dictionary mapping a metric name to a list of + Cross-Validation results of all trials.""" + cv_results_dict_in_list = [trial_.user_attrs for trial_ in self.trials_] + if len(cv_results_dict_in_list) == 0: + cv_results_list_in_dict = {} + else: + cv_results_list_in_dict = { + key: [dict_[key] for dict_ in cv_results_dict_in_list] + for key in cv_results_dict_in_list[0] + } + return cv_results_list_in_dict + + @property + def n_trials_(self) -> int: + """Actual number of trials.""" + + return len(self.trials_) + + @property + def trials_(self) -> list[FrozenTrial]: + """All trials in the :class:`~optuna.study.Study`.""" + + self._check_is_fitted() + + return self.study_.trials + + @property + def user_attrs_(self) -> dict[str, Any]: + """User attributes in the :class:`~optuna.study.Study`.""" + + self._check_is_fitted() + + return self.study_.user_attrs + + @property + def decision_function(self) -> Callable[..., OneDimArrayLikeType | TwoDimArrayLikeType]: + """Call ``decision_function`` on the best estimator. + + This is available only if the underlying estimator supports + ``decision_function`` and ``refit`` is set to :obj:`True`. 
+ """ + + self._check_is_fitted() + + return self.best_estimator_.decision_function + + @property + def inverse_transform(self) -> Callable[..., TwoDimArrayLikeType]: + """Call ``inverse_transform`` on the best estimator. + + This is available only if the underlying estimator supports + ``inverse_transform`` and ``refit`` is set to :obj:`True`. + """ + + self._check_is_fitted() + + return self.best_estimator_.inverse_transform + + @property + def predict(self) -> Callable[..., OneDimArrayLikeType | TwoDimArrayLikeType]: + """Call ``predict`` on the best estimator. + + This is available only if the underlying estimator supports ``predict`` + and ``refit`` is set to :obj:`True`. + """ + + self._check_is_fitted() + + return self.best_estimator_.predict + + @property + def predict_log_proba(self) -> Callable[..., TwoDimArrayLikeType]: + """Call ``predict_log_proba`` on the best estimator. + + This is available only if the underlying estimator supports + ``predict_log_proba`` and ``refit`` is set to :obj:`True`. + """ + + self._check_is_fitted() + + return self.best_estimator_.predict_log_proba + + @property + def predict_proba(self) -> Callable[..., TwoDimArrayLikeType]: + """Call ``predict_proba`` on the best estimator. + + This is available only if the underlying estimator supports + ``predict_proba`` and ``refit`` is set to :obj:`True`. + """ + + self._check_is_fitted() + + return self.best_estimator_.predict_proba + + @property + def score_samples(self) -> Callable[..., OneDimArrayLikeType]: + """Call ``score_samples`` on the best estimator. + + This is available only if the underlying estimator supports + ``score_samples`` and ``refit`` is set to :obj:`True`. 
+ """ + + self._check_is_fitted() + + return self.best_estimator_.score_samples + + @property + def set_user_attr(self) -> Callable[..., None]: + """Call ``set_user_attr`` on the :class:`~optuna.study.Study`.""" + + self._check_is_fitted() + + return self.study_.set_user_attr + + @property + def transform(self) -> Callable[..., TwoDimArrayLikeType]: + """Call ``transform`` on the best estimator. + + This is available only if the underlying estimator supports + ``transform`` and ``refit`` is set to :obj:`True`. + """ + + self._check_is_fitted() + + return self.best_estimator_.transform + + @property + def trials_dataframe(self) -> Callable[..., "pd.DataFrame"]: + """Call ``trials_dataframe`` on the :class:`~optuna.study.Study`.""" + + self._check_is_fitted() + + return self.study_.trials_dataframe + + def __init__( + self, + estimator: "sklearn.base.BaseEstimator", + param_distributions: Mapping[str, distributions.BaseDistribution], + *, + cv: int | "BaseCrossValidator" | Iterable | None = None, + enable_pruning: bool = False, + error_score: Number | float | str = np.nan, + max_iter: int = 1000, + n_jobs: int | None = None, + n_trials: int | None = 10, + random_state: int | np.random.RandomState | None = None, + refit: bool = True, + return_train_score: bool = False, + scoring: Callable[..., float] | str | None = None, + study: study_module.Study | None = None, + subsample: float | int = 1.0, + timeout: float | None = None, + verbose: int = 0, + callbacks: list[Callable[[study_module.Study, FrozenTrial], None]] | None = None, + ) -> None: + _imports.check() + + if not isinstance(param_distributions, dict): + raise TypeError("param_distributions must be a dictionary.") + + # Rejecting deprecated distributions as they may cause cryptic error + # when cloning OptunaSearchCV instance. 
+ # https://github.com/optuna/optuna/issues/4084 + for key, dist in param_distributions.items(): + if dist != _convert_old_distribution_to_new_distribution(dist): + raise ValueError( + f"Deprecated distribution is specified in `{key}` of param_distributions. " + "Rejecting this because it may cause unexpected behavior. " + "Please use new distributions such as FloatDistribution etc." + ) + + self.cv = cv + self.enable_pruning = enable_pruning + self.error_score = error_score + self.estimator = estimator + self.max_iter = max_iter + self.n_trials = n_trials + self.n_jobs = n_jobs if n_jobs else 1 + self.param_distributions = param_distributions + self.random_state = random_state + self.refit = refit + self.return_train_score = return_train_score + self.scoring = scoring + self.study = study + self.subsample = subsample + self.timeout = timeout + self.verbose = verbose + self.callbacks = callbacks + + def _check_is_fitted(self) -> None: + attributes = ["n_splits_", "sample_indices_", "scorer_", "study_"] + + if self.refit: + attributes += ["best_estimator_", "refit_time_"] + + check_is_fitted(self, attributes) + + def _check_params(self) -> None: + if not hasattr(self.estimator, "fit"): + raise ValueError("estimator must be a scikit-learn estimator.") + + for name, distribution in self.param_distributions.items(): + if not isinstance(distribution, distributions.BaseDistribution): + raise ValueError("Value of {} must be a optuna distribution.".format(name)) + + if self.enable_pruning and not hasattr(self.estimator, "partial_fit"): + raise ValueError("estimator must support partial_fit.") + + if self.max_iter <= 0: + raise ValueError("max_iter must be > 0, got {}.".format(self.max_iter)) + + if self.study is not None and self.study.direction != StudyDirection.MAXIMIZE: + raise ValueError("direction of study must be 'maximize'.") + + def _more_tags(self) -> dict[str, bool]: + return {"non_deterministic": True, "no_validation": True} + + def _refit( + self, + X: 
TwoDimArrayLikeType, + y: OneDimArrayLikeType | TwoDimArrayLikeType | None = None, + **fit_params: Any, + ) -> "OptunaSearchCV": + n_samples = _num_samples(X) + + self.best_estimator_ = clone(self.estimator) + + try: + self.best_estimator_.set_params(**self.study_.best_params) + except ValueError as e: + _logger.exception(e) + + _logger.info("Refitting the estimator using {} samples...".format(n_samples)) + + start_time = time() + + self.best_estimator_.fit(X, y, **fit_params) + + self.refit_time_ = time() - start_time + + _logger.info("Finished refitting! (elapsed time: {:.3f} sec.)".format(self.refit_time_)) + + return self + + def fit( + self, + X: TwoDimArrayLikeType, + y: OneDimArrayLikeType | TwoDimArrayLikeType | None = None, + groups: OneDimArrayLikeType | None = None, + **fit_params: Any, + ) -> "OptunaSearchCV": + """Run fit with all sets of parameters. + + Args: + X: + Training data. + + y: + Target variable. + + groups: + Group labels for the samples used while splitting the dataset + into train/validation set. + + **fit_params: + Parameters passed to ``fit`` on the estimator. + + Returns: + self. 
+ """ + + self._check_params() + + random_state = check_random_state(self.random_state) + max_samples = self.subsample + n_samples = _num_samples(X) + old_level = _logger.getEffectiveLevel() + + if self.verbose > 1: + _logger.setLevel(DEBUG) + elif self.verbose > 0: + _logger.setLevel(INFO) + else: + _logger.setLevel(WARNING) + + self.sample_indices_ = np.arange(n_samples) + + if type(max_samples) is float: + max_samples = int(max_samples * n_samples) + + if max_samples < n_samples: + self.sample_indices_ = random_state.choice( + self.sample_indices_, max_samples, replace=False + ) + + self.sample_indices_.sort() + + X_res = _safe_indexing(X, self.sample_indices_) + y_res = _safe_indexing(y, self.sample_indices_) + groups_res = _safe_indexing(groups, self.sample_indices_) + fit_params_res = fit_params + + if fit_params_res is not None: + fit_params_res = _check_fit_params(X, fit_params, self.sample_indices_) + + classifier = is_classifier(self.estimator) + cv = check_cv(self.cv, y_res, classifier=classifier) + + self.n_splits_ = cv.get_n_splits(X_res, y_res, groups=groups_res) + self.scorer_ = check_scoring(self.estimator, scoring=self.scoring) + + if self.study is None: + seed = random_state.randint(0, np.iinfo("int32").max) + sampler = samplers.TPESampler(seed=seed) + + self.study_ = study_module.create_study(direction="maximize", sampler=sampler) + + else: + self.study_ = self.study + + objective = _Objective( + self.estimator, + self.param_distributions, + X_res, + y_res, + cv, + self.enable_pruning, + self.error_score, + fit_params_res, + groups_res, + self.max_iter, + self.return_train_score, + self.scorer_, + ) + + _logger.info( + "Searching the best hyperparameters using {} " + "samples...".format(_num_samples(self.sample_indices_)) + ) + + self.study_.optimize( + objective, + n_jobs=self.n_jobs, + n_trials=self.n_trials, + timeout=self.timeout, + callbacks=self.callbacks, + ) + + _logger.info("Finished hyperparameter search!") + + if self.refit: + 
self._refit(X, y, **fit_params) + + _logger.setLevel(old_level) + + return self + + def score( + self, + X: TwoDimArrayLikeType, + y: OneDimArrayLikeType | TwoDimArrayLikeType | None = None, + ) -> float: + """Return the score on the given data. + + Args: + X: + Data. + + y: + Target variable. + + Returns: + Scalar score. + """ + + return self.scorer_(self.best_estimator_, X, y) From e7e047eeedf366074cbe8bdcc545ba49e1495119 Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 1 Feb 2024 20:38:14 +0900 Subject: [PATCH 03/11] Fetched from https://github.com/optuna/optuna/commit/b4dd96021cbaca27d9549438d04f85bf9d93ecb3. --- tests/test_sklearn.py | 432 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 432 insertions(+) create mode 100644 tests/test_sklearn.py diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py new file mode 100644 index 00000000..d52bf898 --- /dev/null +++ b/tests/test_sklearn.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +from unittest.mock import MagicMock +from unittest.mock import patch +import warnings + +import numpy as np +import pytest +import scipy as sp +from sklearn.datasets import make_blobs +from sklearn.datasets import make_regression +from sklearn.decomposition import PCA +from sklearn.exceptions import ConvergenceWarning +from sklearn.exceptions import NotFittedError +from sklearn.linear_model import LogisticRegression +from sklearn.linear_model import SGDClassifier +from sklearn.neighbors import KernelDensity +from sklearn.tree import DecisionTreeRegressor + +from optuna import distributions +from optuna import integration +from optuna.samplers import BruteForceSampler +from optuna.study import create_study +from optuna.terminator.erroreval import _CROSS_VALIDATION_SCORES_KEY + + +pytestmark = pytest.mark.integration + + +def test_is_arraylike() -> None: + assert integration.sklearn._is_arraylike([]) + assert integration.sklearn._is_arraylike(np.zeros(5)) + assert not integration.sklearn._is_arraylike(1) + + +def 
test_num_samples() -> None: + x1 = np.random.random((10, 10)) + x2 = [1, 2, 3] + assert integration.sklearn._num_samples(x1) == 10 + assert integration.sklearn._num_samples(x2) == 3 + + +def test_make_indexable() -> None: + x1 = np.random.random((10, 10)) + x2 = sp.sparse.coo_matrix(x1) + x3 = [1, 2, 3] + + assert hasattr(integration.sklearn._make_indexable(x1), "__getitem__") + assert hasattr(integration.sklearn._make_indexable(x2), "__getitem__") + assert hasattr(integration.sklearn._make_indexable(x3), "__getitem__") + assert integration.sklearn._make_indexable(None) is None + + +@pytest.mark.parametrize("enable_pruning", [True, False]) +@pytest.mark.parametrize("fit_params", ["", "coef_init"]) +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_optuna_search(enable_pruning: bool, fit_params: str) -> None: + X, y = make_blobs(n_samples=10) + est = SGDClassifier(max_iter=5, tol=1e-03) + param_dist = {"alpha": distributions.FloatDistribution(1e-04, 1e03, log=True)} + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + enable_pruning=enable_pruning, + error_score="raise", + max_iter=5, + random_state=0, + return_train_score=True, + ) + + with pytest.raises(NotFittedError): + optuna_search._check_is_fitted() + + if fit_params == "coef_init" and not enable_pruning: + optuna_search.fit(X, y, coef_init=np.ones((3, 2), dtype=np.float64)) + else: + optuna_search.fit(X, y) + + optuna_search.trials_dataframe() + optuna_search.decision_function(X) + optuna_search.predict(X) + optuna_search.score(X, y) + + +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_optuna_search_properties() -> None: + X, y = make_blobs(n_samples=10) + est = LogisticRegression(tol=1e-03) + param_dist = {"C": distributions.FloatDistribution(1e-04, 1e03, log=True)} + + optuna_search = integration.OptunaSearchCV( + est, param_dist, cv=3, error_score="raise", random_state=0, return_train_score=True + ) + optuna_search.fit(X, y) + 
optuna_search.set_user_attr("dataset", "blobs") + + assert optuna_search._estimator_type == "classifier" + assert isinstance(optuna_search.best_index_, int) + assert isinstance(optuna_search.best_params_, dict) + assert isinstance(optuna_search.cv_results_, dict) + for cv_result_list_ in optuna_search.cv_results_.values(): + assert len(cv_result_list_) == optuna_search.n_trials_ + assert optuna_search.best_score_ is not None + assert optuna_search.best_trial_ is not None + assert np.allclose(optuna_search.classes_, np.array([0, 1, 2])) + assert optuna_search.n_trials_ == 10 + assert optuna_search.user_attrs_ == {"dataset": "blobs"} + assert type(optuna_search.predict_log_proba(X)) == np.ndarray + assert type(optuna_search.predict_proba(X)) == np.ndarray + + +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_optuna_search_score_samples() -> None: + X, y = make_blobs(n_samples=10) + est = KernelDensity() + optuna_search = integration.OptunaSearchCV( + est, {}, cv=3, error_score="raise", random_state=0, return_train_score=True + ) + optuna_search.fit(X) + assert optuna_search.score_samples(X) is not None + + +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_optuna_search_transforms() -> None: + X, y = make_blobs(n_samples=10) + est = PCA() + optuna_search = integration.OptunaSearchCV( + est, {}, cv=3, error_score="raise", random_state=0, return_train_score=True + ) + optuna_search.fit(X) + assert type(optuna_search.transform(X)) == np.ndarray + assert type(optuna_search.inverse_transform(X)) == np.ndarray + + +def test_optuna_search_invalid_estimator() -> None: + X, y = make_blobs(n_samples=10) + est = "not an estimator" + optuna_search = integration.OptunaSearchCV( + est, {}, cv=3, error_score="raise", random_state=0, return_train_score=True + ) + + with pytest.raises(ValueError, match="estimator must be a scikit-learn estimator."): + optuna_search.fit(X) + + +def test_optuna_search_pruning_without_partial_fit() -> None: + X, y = 
make_blobs(n_samples=10) + est = KernelDensity() + param_dist = {} # type: ignore + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + enable_pruning=True, + error_score="raise", + random_state=0, + return_train_score=True, + ) + + with pytest.raises(ValueError, match="estimator must support partial_fit."): + optuna_search.fit(X) + + +def test_optuna_search_negative_max_iter() -> None: + X, y = make_blobs(n_samples=10) + est = KernelDensity() + param_dist = {} # type: ignore + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + max_iter=-1, + error_score="raise", + random_state=0, + return_train_score=True, + ) + + with pytest.raises(ValueError, match="max_iter must be > 0"): + optuna_search.fit(X) + + +def test_optuna_search_tuple_instead_of_distribution() -> None: + X, y = make_blobs(n_samples=10) + est = KernelDensity() + param_dist = {"kernel": ("gaussian", "linear")} + optuna_search = integration.OptunaSearchCV( + est, + param_dist, # type: ignore + cv=3, + error_score="raise", + random_state=0, + return_train_score=True, + ) + + with pytest.raises(ValueError, match="must be a optuna distribution."): + optuna_search.fit(X) + + +def test_optuna_search_study_with_minimize() -> None: + X, y = make_blobs(n_samples=10) + est = KernelDensity() + study = create_study(direction="minimize") + optuna_search = integration.OptunaSearchCV( + est, {}, cv=3, error_score="raise", random_state=0, return_train_score=True, study=study + ) + + with pytest.raises(ValueError, match="direction of study must be 'maximize'."): + optuna_search.fit(X) + + +@pytest.mark.parametrize("verbose", [1, 2]) +def test_optuna_search_verbosity(verbose: int) -> None: + X, y = make_blobs(n_samples=10) + est = KernelDensity() + param_dist = {} # type: ignore + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + error_score="raise", + random_state=0, + return_train_score=True, + verbose=verbose, + ) + optuna_search.fit(X) + + +def 
test_optuna_search_subsample() -> None: + X, y = make_blobs(n_samples=10) + est = KernelDensity() + param_dist = {} # type: ignore + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + error_score="raise", + random_state=0, + return_train_score=True, + subsample=5, + ) + optuna_search.fit(X) + + +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_objective_y_None() -> None: + X, y = make_blobs(n_samples=10) + est = SGDClassifier(max_iter=5, tol=1e-03) + param_dist = {} # type: ignore + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + enable_pruning=True, + error_score="raise", + random_state=0, + return_train_score=True, + ) + + with pytest.raises(ValueError): + optuna_search.fit(X) + + +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_objective_error_score_nan() -> None: + X, y = make_blobs(n_samples=10) + est = SGDClassifier(max_iter=5, tol=1e-03) + param_dist = {} # type: ignore + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + enable_pruning=True, + max_iter=5, + error_score=np.nan, + random_state=0, + return_train_score=True, + ) + + with pytest.raises( + ValueError, + match="This SGDClassifier estimator requires y to be passed, but the target y is None.", + ): + optuna_search.fit(X) + + for trial in optuna_search.study_.get_trials(): + assert np.all(np.isnan(list(trial.intermediate_values.values()))) + + # "_score" stores every score value for train and test validation holds. 
+ for name, value in trial.user_attrs.items(): + if name.endswith("_score"): + assert np.isnan(value) + + +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_objective_error_score_invalid() -> None: + X, y = make_blobs(n_samples=10) + est = SGDClassifier(max_iter=5, tol=1e-03) + param_dist = {} # type: ignore + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + enable_pruning=True, + max_iter=5, + error_score="invalid error score", + random_state=0, + return_train_score=True, + ) + + with pytest.raises(ValueError, match="error_score must be 'raise' or numeric."): + optuna_search.fit(X) + + +# This test checks whether OptunaSearchCV completes the study without halting, even if some trials +# fails due to misconfiguration. +@pytest.mark.parametrize( + "param_dist,all_params", + [ + ({"max_depth": distributions.IntDistribution(0, 1)}, [0, 1]), + ({"max_depth": distributions.IntDistribution(0, 0)}, [0]), + ], +) +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_no_halt_with_error( + param_dist: dict[str, distributions.BaseDistribution], all_params: list[int] +) -> None: + X, y = make_regression(n_samples=100, n_features=10) + estimator = DecisionTreeRegressor() + study = create_study(sampler=BruteForceSampler(), direction="maximize") + + # DecisionTreeRegressor raises ValueError when max_depth==0. + optuna_search = integration.OptunaSearchCV( + estimator, + param_dist, + study=study, + ) + optuna_search.fit(X, y) + all_suggested_values = [t.params["max_depth"] for t in study.trials] + assert len(all_suggested_values) == len(all_params) + for a in all_params: + assert a in all_suggested_values + + +# TODO(himkt): Remove this method with the deletion of deprecated distributions. 
+# https://github.com/optuna/optuna/issues/2941 +@pytest.mark.filterwarnings("ignore::FutureWarning") +def test_optuna_search_convert_deprecated_distribution() -> None: + param_dist = { + "ud": distributions.UniformDistribution(low=0, high=10), + "dud": distributions.DiscreteUniformDistribution(low=0, high=10, q=2), + "lud": distributions.LogUniformDistribution(low=1, high=10), + "id": distributions.IntUniformDistribution(low=0, high=10), + "idd": distributions.IntUniformDistribution(low=0, high=10, step=2), + "ild": distributions.IntLogUniformDistribution(low=1, high=10), + } + + expected_param_dist = { + "ud": distributions.FloatDistribution(low=0, high=10, log=False, step=None), + "dud": distributions.FloatDistribution(low=0, high=10, log=False, step=2), + "lud": distributions.FloatDistribution(low=1, high=10, log=True, step=None), + "id": distributions.IntDistribution(low=0, high=10, log=False, step=1), + "idd": distributions.IntDistribution(low=0, high=10, log=False, step=2), + "ild": distributions.IntDistribution(low=1, high=10, log=True, step=1), + } + + with pytest.raises(ValueError): + optuna_search = integration.OptunaSearchCV( + KernelDensity(), + param_dist, + ) + + # It confirms that ask doesn't convert non-deprecated distributions. 
+ optuna_search = integration.OptunaSearchCV( + KernelDensity(), + expected_param_dist, + ) + + assert optuna_search.param_distributions == expected_param_dist + + +def test_callbacks() -> None: + callbacks = [] + + for _ in range(2): + callback = MagicMock() + callback.__call__ = MagicMock(return_value=None) # type: ignore + callbacks.append(callback) + + n_trials = 5 + X, y = make_blobs(n_samples=10) + est = SGDClassifier(max_iter=5, tol=1e-03) + param_dist = {"alpha": distributions.FloatDistribution(1e-04, 1e03, log=True)} + optuna_search = integration.OptunaSearchCV( + est, + param_dist, + cv=3, + enable_pruning=True, + max_iter=5, + n_trials=n_trials, + error_score=np.nan, + callbacks=callbacks, # type: ignore + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=ConvergenceWarning) + optuna_search.fit(X, y) + + for callback in callbacks: + for trial in optuna_search.trials_: + callback.assert_any_call(optuna_search.study_, trial) + assert callback.call_count == n_trials + + +@pytest.mark.filterwarnings("ignore::UserWarning") +@patch("optuna.integration.sklearn.cross_validate") +def test_terminator_cv_score_reporting(mock: MagicMock) -> None: + scores = { + "fit_time": np.array([2.01, 1.78, 3.22]), + "score_time": np.array([0.33, 0.35, 0.48]), + "test_score": np.array([0.04, 0.80, 0.70]), + } + mock.return_value = scores + + X, _ = make_blobs(n_samples=10) + est = PCA() + optuna_search = integration.OptunaSearchCV(est, {}, cv=3, error_score="raise", random_state=0) + optuna_search.fit(X) + + for trial in optuna_search.study_.trials: + assert (trial.system_attrs[_CROSS_VALIDATION_SCORES_KEY] == scores["test_score"]).all() From d62d80871a1dcbb53e5f46e2ce5964a88b69c431 Mon Sep 17 00:00:00 2001 From: gen740 Date: Fri, 2 Feb 2024 14:15:58 +0900 Subject: [PATCH 04/11] Fetched from https://github.com/optuna/optuna/commit/2f27115324a9a543ec4dda50d9328ce31531855b --- optuna_integration/dask.py | 750 +++++++++++++++++++++++++++++++++++++ 1 
file changed, 750 insertions(+) create mode 100644 optuna_integration/dask.py diff --git a/optuna_integration/dask.py b/optuna_integration/dask.py new file mode 100644 index 00000000..e6367b2f --- /dev/null +++ b/optuna_integration/dask.py @@ -0,0 +1,750 @@ +import asyncio +from datetime import datetime +from typing import Any +from typing import Container +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +import uuid + +import optuna +from optuna._experimental import experimental_class +from optuna._imports import try_import +from optuna._typing import JSONSerializable +from optuna.distributions import BaseDistribution +from optuna.distributions import distribution_to_json +from optuna.distributions import json_to_distribution +from optuna.storages import BaseStorage +from optuna.study import StudyDirection +from optuna.study._frozen import FrozenStudy +from optuna.trial import FrozenTrial +from optuna.trial import TrialState + + +with try_import() as _imports: + import distributed + from distributed.protocol.pickle import dumps + from distributed.protocol.pickle import loads + from distributed.utils import thread_state # type: ignore[attr-defined] + from distributed.worker import get_client + + +def _serialize_frozentrial(trial: FrozenTrial) -> dict: + data = trial.__dict__.copy() + data["state"] = data["state"].name + attrs = [a for a in data.keys() if a.startswith("_")] + for attr in attrs: + data[attr[1:]] = data.pop(attr) + data["system_attrs"] = ( + dumps(data["system_attrs"]) # type: ignore[no-untyped-call] + if data["system_attrs"] + else {} + ) + data["user_attrs"] = ( + dumps(data["user_attrs"]) if data["user_attrs"] else {} # type: ignore[no-untyped-call] + ) + data["distributions"] = {k: distribution_to_json(v) for k, v in data["distributions"].items()} + if data["datetime_start"] is 
not None: + data["datetime_start"] = data["datetime_start"].isoformat(timespec="microseconds") + if data["datetime_complete"] is not None: + data["datetime_complete"] = data["datetime_complete"].isoformat(timespec="microseconds") + data["value"] = None + return data + + +def _deserialize_frozentrial(data: dict) -> FrozenTrial: + data["state"] = TrialState[data["state"]] + data["distributions"] = {k: json_to_distribution(v) for k, v in data["distributions"].items()} + if data["datetime_start"] is not None: + data["datetime_start"] = datetime.fromisoformat(data["datetime_start"]) + if data["datetime_complete"] is not None: + data["datetime_complete"] = datetime.fromisoformat(data["datetime_complete"]) + data["system_attrs"] = ( + loads(data["system_attrs"]) # type: ignore[no-untyped-call] + if data["system_attrs"] + else {} + ) + data["user_attrs"] = ( + loads(data["user_attrs"]) if data["user_attrs"] else {} # type: ignore[no-untyped-call] + ) + return FrozenTrial(**data) + + +def _serialize_frozenstudy(study: FrozenStudy) -> dict: + data = { + "directions": [d.name for d in study._directions], + "study_id": study._study_id, + "study_name": study.study_name, + "user_attrs": dumps(study.user_attrs) # type: ignore[no-untyped-call] + if study.user_attrs + else {}, + "system_attrs": dumps(study.system_attrs) # type: ignore[no-untyped-call] + if study.system_attrs + else {}, + } + return data + + +def _deserialize_frozenstudy(data: dict) -> FrozenStudy: + data["directions"] = [StudyDirection[d] for d in data["directions"]] + data["direction"] = None + data["system_attrs"] = ( + loads(data["system_attrs"]) # type: ignore[no-untyped-call] + if data["system_attrs"] + else {} + ) + data["user_attrs"] = ( + loads(data["user_attrs"]) if data["user_attrs"] else {} # type: ignore[no-untyped-call] + ) + return FrozenStudy(**data) + + +class _OptunaSchedulerExtension: + def __init__(self, scheduler: "distributed.Scheduler"): + self.scheduler = scheduler + self.storages: Dict[str, 
BaseStorage] = {} + + methods = [ + "create_new_study", + "delete_study", + "set_study_user_attr", + "set_study_system_attr", + "get_study_id_from_name", + "get_study_name_from_id", + "get_study_directions", + "get_study_user_attrs", + "get_study_system_attrs", + "get_all_studies", + "create_new_trial", + "set_trial_param", + "get_trial_id_from_study_id_trial_number", + "get_trial_number_from_id", + "get_trial_param", + "set_trial_state_values", + "set_trial_intermediate_value", + "set_trial_user_attr", + "set_trial_system_attr", + "get_trial", + "get_all_trials", + "get_n_trials", + ] + handlers = {f"optuna_{method}": getattr(self, method) for method in methods} + self.scheduler.handlers.update(handlers) + + self.scheduler.extensions["optuna"] = self + + def get_storage(self, name: str) -> BaseStorage: + return self.storages[name] + + def create_new_study( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + directions: List[str], + study_name: Optional[str] = None, + ) -> int: + return self.get_storage(storage_name).create_new_study( + directions=[StudyDirection[direction] for direction in directions], + study_name=study_name, + ) + + def delete_study( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + ) -> None: + return self.get_storage(storage_name).delete_study(study_id=study_id) + + def set_study_user_attr( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + key: str, + value: Any, + ) -> None: + return self.get_storage(storage_name).set_study_user_attr( + study_id=study_id, key=key, value=loads(value) # type: ignore[no-untyped-call] + ) + + def set_study_system_attr( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + key: str, + value: Any, + ) -> None: + return self.get_storage(storage_name).set_study_system_attr( + study_id=study_id, + key=key, + value=loads(value), # type: ignore[no-untyped-call] + ) + + def get_study_id_from_name( + self, + 
comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_name: str, + ) -> int: + return self.get_storage(storage_name).get_study_id_from_name(study_name=study_name) + + def get_study_name_from_id( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + ) -> str: + return self.get_storage(storage_name).get_study_name_from_id(study_id=study_id) + + def get_study_directions( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + ) -> List[str]: + directions = self.get_storage(storage_name).get_study_directions(study_id=study_id) + return [direction.name for direction in directions] + + def get_study_user_attrs( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + ) -> Dict[str, Any]: + return dumps( + self.get_storage(storage_name).get_study_user_attrs( # type: ignore[no-untyped-call] + study_id=study_id + ) + ) + + def get_study_system_attrs( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + ) -> Dict[str, Any]: + return dumps( + self.get_storage(storage_name).get_study_system_attrs( # type: ignore[no-untyped-call] + study_id=study_id + ) + ) + + def get_all_studies(self, comm: "distributed.comm.tcp.TCP", storage_name: str) -> List[dict]: + studies = self.get_storage(storage_name).get_all_studies() + return [_serialize_frozenstudy(s) for s in studies] + + def create_new_trial( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + template_trial: Optional[dict] = None, + ) -> int: + deserialized_template_trial = None + if template_trial is not None: + deserialized_template_trial = _deserialize_frozentrial(template_trial) + return self.get_storage(storage_name).create_new_trial( + study_id=study_id, + template_trial=deserialized_template_trial, + ) + + def set_trial_param( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + trial_id: int, + param_name: str, + param_value_internal: float, + 
distribution: str, + ) -> None: + return self.get_storage(storage_name).set_trial_param( + trial_id=trial_id, + param_name=param_name, + param_value_internal=param_value_internal, + distribution=json_to_distribution(distribution), + ) + + def get_trial_id_from_study_id_trial_number( + self, comm: "distributed.comm.tcp.TCP", storage_name: str, study_id: int, trial_number: int + ) -> int: + return self.get_storage(storage_name).get_trial_id_from_study_id_trial_number( + study_id=study_id, + trial_number=trial_number, + ) + + def get_trial_number_from_id( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + trial_id: int, + ) -> int: + return self.get_storage(storage_name).get_trial_number_from_id(trial_id=trial_id) + + def get_trial_param( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + trial_id: int, + param_name: str, + ) -> float: + return self.get_storage(storage_name).get_trial_param( + trial_id=trial_id, + param_name=param_name, + ) + + def set_trial_state_values( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + trial_id: int, + state: str, + values: Optional[Sequence[float]] = None, + ) -> bool: + return self.get_storage(storage_name).set_trial_state_values( + trial_id=trial_id, + state=TrialState[state], + values=values, + ) + + def set_trial_intermediate_value( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + trial_id: int, + step: int, + intermediate_value: float, + ) -> None: + return self.get_storage(storage_name).set_trial_intermediate_value( + trial_id=trial_id, + step=step, + intermediate_value=intermediate_value, + ) + + def set_trial_user_attr( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + trial_id: int, + key: str, + value: Any, + ) -> None: + return self.get_storage(storage_name).set_trial_user_attr( + trial_id=trial_id, + key=key, + value=loads(value), # type: ignore[no-untyped-call] + ) + + def set_trial_system_attr( + self, + comm: 
"distributed.comm.tcp.TCP", + storage_name: str, + trial_id: int, + key: str, + value: JSONSerializable, + ) -> None: + return self.get_storage(storage_name).set_trial_system_attr( + trial_id=trial_id, + key=key, + value=loads(value), # type: ignore[no-untyped-call] + ) + + def get_trial( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + trial_id: int, + ) -> dict: + trial = self.get_storage(storage_name).get_trial(trial_id=trial_id) + return _serialize_frozentrial(trial) + + def get_all_trials( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + deepcopy: bool = True, + states: Optional[Tuple[str, ...]] = None, + ) -> List[dict]: + deserialized_states = None + if states is not None: + deserialized_states = tuple(TrialState[s] for s in states) + trials = self.get_storage(storage_name).get_all_trials( + study_id=study_id, + deepcopy=deepcopy, + states=deserialized_states, + ) + return [_serialize_frozentrial(t) for t in trials] + + def get_n_trials( + self, + comm: "distributed.comm.tcp.TCP", + storage_name: str, + study_id: int, + state: Optional[Union[Tuple[str, ...], str]] = None, + ) -> int: + deserialized_state: Optional[Union[Tuple[TrialState, ...], TrialState]] = None + if state is not None: + if isinstance(state, str): + deserialized_state = TrialState[state] + else: + deserialized_state = tuple(TrialState[s] for s in state) + return self.get_storage(storage_name).get_n_trials( + study_id=study_id, + state=deserialized_state, + ) + + +def _register_with_scheduler( + dask_scheduler: "distributed.Scheduler", storage: Union[None, str, BaseStorage], name: str +) -> None: + if "optuna" not in dask_scheduler.extensions: + ext = _OptunaSchedulerExtension(dask_scheduler) + else: + ext = dask_scheduler.extensions["optuna"] + + if name not in ext.storages: + ext.storages[name] = optuna.storages.get_storage(storage) + + +@experimental_class("3.1.0") +class DaskStorage(BaseStorage): + """Dask-compatible storage class. 
+ + This storage class wraps an Optuna storage class (e.g. Optuna’s in-memory or sqlite storage) + and is used to run optimization trials in parallel on a Dask cluster. + The underlying Optuna storage object lives on the cluster’s scheduler and any method call on + the :obj:`DaskStorage` instance results in the same method being called on the underlying + Optuna storage object. + + See `this example `_ or the following YouTube video + for how to use :obj:`DaskStorage` to extend Optuna's in-memory storage class to run across + multiple processes. + + .. raw:: html + +
+
+ + Args: + storage: + Optuna storage url to use for underlying Optuna storage class to wrap + (e.g. :obj:`None` for in-memory storage, ``sqlite:///example.db`` + for SQLite storage). Defaults to :obj:`None`. + + name: + Unique identifier for the Dask storage class. Specifying a custom name can sometimes + be useful for logging or debugging. If :obj:`None` is provided, + a random name will be automatically generated. + + client: + Dask ``Client`` to connect to. If not provided, will attempt to find an + existing ``Client``. + + register: + Whether or not to register this storage instance with the cluster scheduler. + Most common usage of this storage class will not need to specify this argument. + Defaults to ``True``. + + """ + + def __init__( + self, + storage: Union[None, str, BaseStorage] = None, + name: Optional[str] = None, + client: Optional["distributed.Client"] = None, + register: bool = True, + ): + _imports.check() + self.name = name or f"dask-storage-{uuid.uuid4().hex}" + self._client = client + if register: + if self.client.asynchronous or getattr(thread_state, "on_event_loop_thread", False): + + async def _register() -> DaskStorage: + await self.client.run_on_scheduler( # type: ignore[no-untyped-call] + _register_with_scheduler, storage=storage, name=self.name + ) + return self + + self._started = asyncio.ensure_future(_register()) + else: + self.client.run_on_scheduler( # type: ignore[no-untyped-call] + _register_with_scheduler, storage=storage, name=self.name + ) + + @property + def client(self) -> "distributed.Client": + if not self._client: + self._client = get_client() + return self._client + + def __await__(self) -> Generator[Any, None, "DaskStorage"]: + if hasattr(self, "_started"): + return self._started.__await__() + else: + + async def _() -> DaskStorage: + return self + + return _().__await__() + + def __reduce__(self) -> tuple: + # We don't have a reference to underlying Optuna storage instance which lives + # on the scheduler. 
This is okay since this DaskStorage instance has already been + # registered with the scheduler, and ``storage`` is only ever needed during the + # scheduler registration process. We use ``storage=None`` below by convention. + return (DaskStorage, (None, self.name, None, False)) + + def get_base_storage(self) -> BaseStorage: + """Retrieve underlying Optuna storage instance from the scheduler. + + This is a convenience method to extract the Optuna storage instance stored on + the Dask scheduler process to the local Python process. + """ + + def _get_base_storage(dask_scheduler: distributed.Scheduler, name: str) -> BaseStorage: + return dask_scheduler.extensions["optuna"].storages[name] + + return self.client.run_on_scheduler( # type: ignore[no-untyped-call] + _get_base_storage, name=self.name + ) + + def create_new_study( + self, directions: Sequence[StudyDirection], study_name: Optional[str] = None + ) -> int: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_create_new_study, # type: ignore[union-attr] + storage_name=self.name, + study_name=study_name, + directions=[direction.name for direction in directions], + ) + + def delete_study(self, study_id: int) -> None: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_delete_study, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + ) + + def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_set_study_user_attr, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + key=key, + value=dumps(value), # type: ignore[no-untyped-call] + ) + + def set_study_system_attr(self, study_id: int, key: str, value: Any) -> None: + return self.client.sync( + self.client.scheduler.optuna_set_study_system_attr, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + key=key, + 
value=dumps(value), # type: ignore[no-untyped-call] + ) + + # Basic study access + + def get_study_id_from_name(self, study_name: str) -> int: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_study_id_from_name, # type: ignore[union-attr] + study_name=study_name, + storage_name=self.name, + ) + + def get_study_name_from_id(self, study_id: int) -> str: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_study_name_from_id, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + ) + + def get_study_directions(self, study_id: int) -> List[StudyDirection]: + directions = self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_study_directions, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + ) + return [StudyDirection[direction] for direction in directions] + + def get_study_user_attrs(self, study_id: int) -> Dict[str, Any]: + return loads( # type: ignore[no-untyped-call] + self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_study_user_attrs, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + ) + ) + + def get_study_system_attrs(self, study_id: int) -> Dict[str, Any]: + return loads( # type: ignore[no-untyped-call] + self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_study_system_attrs, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + ) + ) + + def get_all_studies(self) -> List[FrozenStudy]: + results = self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_all_studies, # type: ignore[union-attr] + storage_name=self.name, + ) + return [_deserialize_frozenstudy(i) for i in results] + + # Basic trial manipulation + + def create_new_trial(self, study_id: int, template_trial: Optional[FrozenTrial] = None) -> int: + serialized_template_trial = None + if template_trial 
is not None: + serialized_template_trial = _serialize_frozentrial(template_trial) + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_create_new_trial, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + template_trial=serialized_template_trial, + ) + + def set_trial_param( + self, + trial_id: int, + param_name: str, + param_value_internal: float, + distribution: BaseDistribution, + ) -> None: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_set_trial_param, # type: ignore[union-attr] + storage_name=self.name, + trial_id=trial_id, + param_name=param_name, + param_value_internal=param_value_internal, + distribution=distribution_to_json(distribution), + ) + + def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_trial_id_from_study_id_trial_number, # type: ignore[union-attr] # NOQA: E501 + storage_name=self.name, + study_id=study_id, + trial_number=trial_number, + ) + + def get_trial_number_from_id(self, trial_id: int) -> int: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_trial_number_from_id, # type: ignore[union-attr] + storage_name=self.name, + trial_id=trial_id, + ) + + def get_trial_param(self, trial_id: int, param_name: str) -> float: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_trial_param, # type: ignore[union-attr] + storage_name=self.name, + trial_id=trial_id, + param_name=param_name, + ) + + def set_trial_state_values( + self, trial_id: int, state: TrialState, values: Optional[Sequence[float]] = None + ) -> bool: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_set_trial_state_values, # type: ignore[union-attr] + storage_name=self.name, + trial_id=trial_id, + state=state.name, + 
values=values, + ) + + def set_trial_intermediate_value( + self, trial_id: int, step: int, intermediate_value: float + ) -> None: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_set_trial_intermediate_value, # type: ignore[union-attr] + storage_name=self.name, + trial_id=trial_id, + step=step, + intermediate_value=intermediate_value, + ) + + def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_set_trial_user_attr, # type: ignore[union-attr] + storage_name=self.name, + trial_id=trial_id, + key=key, + value=dumps(value), # type: ignore[no-untyped-call] + ) + + def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None: + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_set_trial_system_attr, # type: ignore[union-attr] + storage_name=self.name, + trial_id=trial_id, + key=key, + value=dumps(value), # type: ignore[no-untyped-call] + ) + + # Basic trial access + + async def _get_trial(self, trial_id: int) -> FrozenTrial: + serialized_trial = await self.client.scheduler.optuna_get_trial( # type: ignore[union-attr] # NOQA: E501 + trial_id=trial_id, storage_name=self.name + ) + return _deserialize_frozentrial(serialized_trial) + + def get_trial(self, trial_id: int) -> FrozenTrial: + return self.client.sync( # type: ignore[no-untyped-call] + self._get_trial, trial_id=trial_id + ) + + async def _get_all_trials( + self, study_id: int, deepcopy: bool = True, states: Optional[Iterable[TrialState]] = None + ) -> List[FrozenTrial]: + serialized_states = None + if states is not None: + serialized_states = tuple(s.name for s in states) + serialized_trials = await self.client.scheduler.optuna_get_all_trials( # type: ignore[union-attr] # NOQA: E501 + storage_name=self.name, + study_id=study_id, + deepcopy=deepcopy, + states=serialized_states, + ) + return 
[_deserialize_frozentrial(t) for t in serialized_trials] + + def get_all_trials( + self, study_id: int, deepcopy: bool = True, states: Optional[Container[TrialState]] = None + ) -> List[FrozenTrial]: + return self.client.sync( # type: ignore[no-untyped-call] + self._get_all_trials, + study_id=study_id, + deepcopy=deepcopy, + states=states, + ) + + def get_n_trials( + self, study_id: int, state: Optional[Union[Tuple[TrialState, ...], TrialState]] = None + ) -> int: + serialized_state: Optional[Union[Tuple[str, ...], str]] = None + if state is not None: + if isinstance(state, TrialState): + serialized_state = state.name + else: + serialized_state = tuple(s.name for s in state) + return self.client.sync( # type: ignore[no-untyped-call] + self.client.scheduler.optuna_get_n_trials, # type: ignore[union-attr] + storage_name=self.name, + study_id=study_id, + state=serialized_state, + ) From 9660092358bf9ee05d30bef0f9518316a4c333b2 Mon Sep 17 00:00:00 2001 From: gen740 Date: Fri, 2 Feb 2024 14:16:27 +0900 Subject: [PATCH 05/11] Fetched from https://github.com/optuna/optuna/commit/ff991f1e31fe9d62793e6e328cc9a5f9e7ca0df4 --- tests/test_dask.py | 148 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 tests/test_dask.py diff --git a/tests/test_dask.py b/tests/test_dask.py new file mode 100644 index 00000000..8063c4a6 --- /dev/null +++ b/tests/test_dask.py @@ -0,0 +1,148 @@ +from contextlib import contextmanager +import time +from typing import Iterator + +import numpy as np +import pytest + +import optuna +from optuna._imports import try_import +from optuna.integration.dask import _OptunaSchedulerExtension +from optuna.integration.dask import DaskStorage +from optuna.testing.tempfile_pool import NamedTemporaryFilePool +from optuna.trial import Trial + + +with try_import() as _imports: + from distributed import Client + from distributed import Scheduler + from distributed import wait + from distributed import Worker + from 
distributed.utils_test import clean + from distributed.utils_test import gen_cluster + +pytestmark = pytest.mark.integration + + +STORAGE_MODES = ["inmemory", "sqlite"] + + +@contextmanager +def get_storage_url(specifier: str) -> Iterator: + tmpfile = None + try: + if specifier == "inmemory": + url = None + elif specifier == "sqlite": + tmpfile = NamedTemporaryFilePool().tempfile() + url = "sqlite:///{}".format(tmpfile.name) + else: + raise ValueError( + "Invalid specifier entered. Was expecting 'inmemory' or 'sqlite'" + f"but got {specifier} instead" + ) + yield url + finally: + if tmpfile is not None: + tmpfile.close() + + +def objective(trial: Trial) -> float: + x = trial.suggest_float("x", -10, 10) + return (x - 2) ** 2 + + +def objective_slow(trial: Trial) -> float: + time.sleep(2) + return objective(trial) + + +@pytest.fixture +def client() -> "Client": # type: ignore[misc] + with clean(): + with Client(dashboard_address=":0") as client: # type: ignore[no-untyped-call] + yield client + + +def test_experimental(client: "Client") -> None: + with pytest.warns(optuna.exceptions.ExperimentalWarning): + DaskStorage() + + +def test_no_client_informative_error() -> None: + with pytest.raises(ValueError, match="No global client found"): + DaskStorage() + + +def test_name_unique(client: "Client") -> None: + s1 = DaskStorage() + s2 = DaskStorage() + assert s1.name != s2.name + + +@pytest.mark.parametrize("storage_specifier", STORAGE_MODES) +def test_study_optimize(client: "Client", storage_specifier: str) -> None: + with get_storage_url(storage_specifier) as url: + storage = DaskStorage(storage=url) + study = optuna.create_study(storage=storage) + assert not study.trials + futures = [ + client.submit( # type: ignore[no-untyped-call] + study.optimize, objective, n_trials=1, pure=False + ) + for _ in range(10) + ] + wait(futures) # type: ignore[no-untyped-call] + assert len(study.trials) == 10 + + +@pytest.mark.parametrize("storage_specifier", STORAGE_MODES) +def 
test_get_base_storage(client: "Client", storage_specifier: str) -> None: + with get_storage_url(storage_specifier) as url: + dask_storage = DaskStorage(url) + storage = dask_storage.get_base_storage() + expected_type = type(optuna.storages.get_storage(url)) + assert type(storage) is expected_type + + +@pytest.mark.parametrize("direction", ["maximize", "minimize"]) +def test_study_direction_best_value(client: "Client", direction: str) -> None: + # Regression test for https://github.com/jrbourbeau/dask-optuna/issues/15 + pytest.importorskip("pandas") + storage = DaskStorage() + study = optuna.create_study(storage=storage, direction=direction) + f = client.submit(study.optimize, objective, n_trials=10) # type: ignore[no-untyped-call] + wait(f) # type: ignore[no-untyped-call] + + # Ensure that study.best_value matches up with the expected value from + # the trials DataFrame + trials_value = study.trials_dataframe()["value"] + if direction == "maximize": + expected = trials_value.max() + else: + expected = trials_value.min() + + np.testing.assert_allclose(expected, study.best_value) + + +if _imports.is_successful(): + + @gen_cluster(client=True) + async def test_daskstorage_registers_extension( + c: "Client", s: "Scheduler", a: "Worker", b: "Worker" + ) -> None: + assert "optuna" not in s.extensions + await DaskStorage() + assert "optuna" in s.extensions + assert type(s.extensions["optuna"]) is _OptunaSchedulerExtension + + @gen_cluster(client=True) + async def test_name(c: "Client", s: "Scheduler", a: "Worker", b: "Worker") -> None: + await DaskStorage(name="foo") + ext = s.extensions["optuna"] + assert len(ext.storages) == 1 + assert type(ext.storages["foo"]) is optuna.storages.InMemoryStorage + + await DaskStorage(name="bar") + assert len(ext.storages) == 2 + assert type(ext.storages["bar"]) is optuna.storages.InMemoryStorage From 77419d537f7d07af6c5b7ed38366357c1c844e64 Mon Sep 17 00:00:00 2001 From: gen740 Date: Fri, 2 Feb 2024 14:36:48 +0900 Subject: [PATCH 06/11] 
Apply changes --- README.md | 1 + docs/source/reference/index.rst | 9 +++++++++ optuna_integration/dask.py | 2 +- pyproject.toml | 1 + tests/test_dask.py | 6 +++--- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d71517a5..4d8f0281 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc * [CatBoost](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#catboost) ([example](https://github.com/optuna/optuna-examples/blob/main/catboost/catboost_pruning.py)) * [Chainer](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainer) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainer_integration.py)) * [ChainerMN](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainermn) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainermn_simple.py)) +* [Dask](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#dask) ([example](https://github.com/optuna/optuna-examples/tree/main/dask/dask_simple.py)) * FastAI ([V1](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv1) ([example](https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv1_simple.py)), ([V2](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv2) ([example]https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv2_simple.py))) * [Keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#keras) ([example](https://github.com/optuna/optuna-examples/tree/main/keras)) * [MXNet](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#mxnet) ([example](https://github.com/optuna/optuna-examples/tree/main/mxnet)) diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index 01013e8a..128d5bd9 100644 --- a/docs/source/reference/index.rst 
+++ b/docs/source/reference/index.rst @@ -55,6 +55,15 @@ Chainer optuna.integration.ChainerPruningExtension optuna.integration.ChainerMNStudy +Dask +---- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + optuna.integration.DaskStorage + fast.ai ------- diff --git a/optuna_integration/dask.py b/optuna_integration/dask.py index e6367b2f..44a01925 100644 --- a/optuna_integration/dask.py +++ b/optuna_integration/dask.py @@ -14,7 +14,7 @@ import optuna from optuna._experimental import experimental_class -from optuna._imports import try_import +from optuna_integration._imports import try_import from optuna._typing import JSONSerializable from optuna.distributions import BaseDistribution from optuna.distributions import distribution_to_json diff --git a/pyproject.toml b/pyproject.toml index d31d46e5..d22fed52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ all = [ "catalyst", "catboost>=0.26; sys_platform!='darwin'", "catboost>=0.26,<1.2; sys_platform=='darwin'", + "distributed", "fastai", "mxnet", "shap", diff --git a/tests/test_dask.py b/tests/test_dask.py index 8063c4a6..590a7800 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -6,9 +6,9 @@ import pytest import optuna -from optuna._imports import try_import -from optuna.integration.dask import _OptunaSchedulerExtension -from optuna.integration.dask import DaskStorage +from optuna_integration._imports import try_import +from optuna_integration.dask import _OptunaSchedulerExtension +from optuna_integration.dask import DaskStorage from optuna.testing.tempfile_pool import NamedTemporaryFilePool from optuna.trial import Trial From bfad071ca7cc7bbd9bd55fa38d3ad838ba367b2e Mon Sep 17 00:00:00 2001 From: gen740 Date: Fri, 2 Feb 2024 15:12:50 +0900 Subject: [PATCH 07/11] Remove unused pytest mark --- tests/test_dask.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_dask.py b/tests/test_dask.py index 590a7800..cf32e615 100644 --- a/tests/test_dask.py +++ 
b/tests/test_dask.py @@ -21,8 +21,6 @@ from distributed.utils_test import clean from distributed.utils_test import gen_cluster -pytestmark = pytest.mark.integration - STORAGE_MODES = ["inmemory", "sqlite"] From fcb0ff82faa64fca80090f8224cb6adcc3e9acbd Mon Sep 17 00:00:00 2001 From: gen740 Date: Fri, 2 Feb 2024 15:42:19 +0900 Subject: [PATCH 08/11] Fix unused (and unintended) import --- tests/importance_tests/test_init.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/importance_tests/test_init.py b/tests/importance_tests/test_init.py index 865b8953..1b1d0dd1 100644 --- a/tests/importance_tests/test_init.py +++ b/tests/importance_tests/test_init.py @@ -4,7 +4,6 @@ import optuna from optuna import samplers from optuna.importance import get_param_importances -import optuna.integration.shap from optuna.samplers import RandomSampler from optuna.study import create_study from optuna.testing.objectives import pruned_objective From dd7a78483a49282e77236477f0b3d9ba70f93d18 Mon Sep 17 00:00:00 2001 From: gen740 Date: Fri, 2 Feb 2024 15:49:00 +0900 Subject: [PATCH 09/11] Also delete duplicated import --- tests/importance_tests/test_init.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/importance_tests/test_init.py b/tests/importance_tests/test_init.py index 1b1d0dd1..69edcf8e 100644 --- a/tests/importance_tests/test_init.py +++ b/tests/importance_tests/test_init.py @@ -2,7 +2,6 @@ import numpy as np import optuna -from optuna import samplers from optuna.importance import get_param_importances from optuna.samplers import RandomSampler from optuna.study import create_study @@ -61,7 +60,7 @@ def objective(trial: Trial) -> float: return value with StorageSupplier(storage_mode) as storage: - study = create_study(storage=storage, sampler=samplers.RandomSampler()) + study = create_study(storage=storage, sampler=RandomSampler()) study.optimize(objective, n_trials=3) param_importance = get_param_importances( From 
7b8d965d356724f63a2e4cd7fca55a1e04ec0c4e Mon Sep 17 00:00:00 2001 From: gen740 Date: Fri, 2 Feb 2024 15:56:36 +0900 Subject: [PATCH 10/11] Apply isort --- optuna_integration/dask.py | 3 ++- tests/test_dask.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/optuna_integration/dask.py b/optuna_integration/dask.py index 44a01925..777c12e6 100644 --- a/optuna_integration/dask.py +++ b/optuna_integration/dask.py @@ -14,7 +14,6 @@ import optuna from optuna._experimental import experimental_class -from optuna_integration._imports import try_import from optuna._typing import JSONSerializable from optuna.distributions import BaseDistribution from optuna.distributions import distribution_to_json @@ -25,6 +24,8 @@ from optuna.trial import FrozenTrial from optuna.trial import TrialState +from optuna_integration._imports import try_import + with try_import() as _imports: import distributed diff --git a/tests/test_dask.py b/tests/test_dask.py index cf32e615..9acd6382 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -3,14 +3,14 @@ from typing import Iterator import numpy as np +import optuna +from optuna.testing.tempfile_pool import NamedTemporaryFilePool +from optuna.trial import Trial import pytest -import optuna from optuna_integration._imports import try_import from optuna_integration.dask import _OptunaSchedulerExtension from optuna_integration.dask import DaskStorage -from optuna.testing.tempfile_pool import NamedTemporaryFilePool -from optuna.trial import Trial with try_import() as _imports: From 29e74be65cd70700b8896c8f5de1e2989e6bd9fd Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 1 Feb 2024 20:37:44 +0900 Subject: [PATCH 11/11] Add sklearn integration and apply related changes. 
--- README.md | 1 + docs/source/reference/index.rst | 49 ++++++++------ optuna_integration/__init__.py | 110 ++++++++++++++++++++++++++++++++ optuna_integration/shap.py | 3 +- optuna_integration/sklearn.py | 57 +++++++++++------ pyproject.toml | 6 ++ tests/test_sklearn.py | 14 ++-- 7 files changed, 192 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index d71517a5..6b3b200e 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc * [Keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#keras) ([example](https://github.com/optuna/optuna-examples/tree/main/keras)) * [MXNet](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#mxnet) ([example](https://github.com/optuna/optuna-examples/tree/main/mxnet)) * [SHAP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#shap) +* [sklearn](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#sklearn) ([example](https://github.com/optuna/optuna-examples/tree/main/sklearn/sklearn_optuna_search_cv_simple.py)) * [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py)) * [TensorBoard](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorboard) ([example](https://github.com/optuna/optuna-examples/tree/main/tensorboard/tensorboard_simple.py)) * [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py)) diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index 01013e8a..94a62662 100644 --- a/docs/source/reference/index.rst +++ b/docs/source/reference/index.rst @@ -10,11 +10,11 @@ The former is provided for backward compatibility. 
For most of the ML frameworks supported by Optuna, the corresponding Optuna integration class serves only to implement a callback object and functions, compliant with the framework's specific callback API, to be called with each intermediate step in the model training. The functionality implemented in these callbacks across the different ML frameworks includes: -(1) Reporting intermediate model scores back to the Optuna trial using :func:`optuna.trial.Trial.report`, -(2) According to the results of :func:`optuna.trial.Trial.should_prune`, pruning the current model by raising :func:`optuna.TrialPruned`, and -(3) Reporting intermediate Optuna data such as the current trial number back to the framework, as done in :class:`~optuna.integration.MLflowCallback`. +(1) Reporting intermediate model scores back to the Optuna trial using `optuna.trial.Trial.report `_, +(2) According to the results of `optuna.trial.Trial.should_prune `_, pruning the current model by raising `optuna.TrialPruned `_, and +(3) Reporting intermediate Optuna data such as the current trial number back to the framework, as done in :class:`~optuna_integration.MLflowCallback`. -For scikit-learn, an integrated :class:`~optuna.integration.OptunaSearchCV` estimator is available that combines scikit-learn BaseEstimator functionality with access to a class-level ``Study`` object. +For scikit-learn, an integrated :class:`~optuna_integration.OptunaSearchCV` estimator is available that combines scikit-learn BaseEstimator functionality with access to a class-level ``Study`` object. 
AllenNLP -------- @@ -23,9 +23,9 @@ AllenNLP :toctree: generated/ :nosignatures: - optuna.integration.AllenNLPExecutor - optuna.integration.allennlp.dump_best_config - optuna.integration.AllenNLPPruningCallback + optuna_integration.AllenNLPExecutor + optuna_integration.allennlp.dump_best_config + optuna_integration.AllenNLPPruningCallback Catalyst -------- @@ -34,7 +34,7 @@ Catalyst :toctree: generated/ :nosignatures: - optuna.integration.CatalystPruningCallback + optuna_integration.CatalystPruningCallback CatBoost -------- @@ -43,7 +43,7 @@ CatBoost :toctree: generated/ :nosignatures: - optuna.integration.CatBoostPruningCallback + optuna_integration.CatBoostPruningCallback Chainer ------- @@ -52,8 +52,8 @@ Chainer :toctree: generated/ :nosignatures: - optuna.integration.ChainerPruningExtension - optuna.integration.ChainerMNStudy + optuna_integration.ChainerPruningExtension + optuna_integration.ChainerMNStudy fast.ai ------- @@ -62,9 +62,9 @@ fast.ai :toctree: generated/ :nosignatures: - optuna.integration.FastAIV1PruningCallback - optuna.integration.FastAIV2PruningCallback - optuna.integration.FastAIPruningCallback + optuna_integration.FastAIV1PruningCallback + optuna_integration.FastAIV2PruningCallback + optuna_integration.FastAIPruningCallback Keras ----- @@ -73,7 +73,7 @@ Keras :toctree: generated/ :nosignatures: - optuna.integration.KerasPruningCallback + optuna_integration.KerasPruningCallback MXNet ----- @@ -82,7 +82,7 @@ MXNet :toctree: generated/ :nosignatures: - optuna.integration.MXNetPruningCallback + optuna_integration.MXNetPruningCallback SHAP ---- @@ -91,7 +91,16 @@ SHAP :toctree: generated/ :nosignatures: - optuna.integration.ShapleyImportanceEvaluator + optuna_integration.ShapleyImportanceEvaluator + +sklearn +------- + +.. 
autosummary:: + :toctree: generated/ + :nosignatures: + + optuna_integration.OptunaSearchCV skorch ------ @@ -100,7 +109,7 @@ skorch :toctree: generated/ :nosignatures: - optuna.integration.SkorchPruningCallback + optuna_integration.SkorchPruningCallback TensorBoard ----------- @@ -109,7 +118,7 @@ TensorBoard :toctree: generated/ :nosignatures: - optuna.integration.TensorBoardCallback + optuna_integration.TensorBoardCallback TensorFlow ---------- @@ -118,4 +127,4 @@ TensorFlow :toctree: generated/ :nosignatures: - optuna.integration.TFKerasPruningCallback + optuna_integration.TFKerasPruningCallback diff --git a/optuna_integration/__init__.py b/optuna_integration/__init__.py index e69de29b..10d47ef6 100644 --- a/optuna_integration/__init__.py +++ b/optuna_integration/__init__.py @@ -0,0 +1,110 @@ +import os +import sys +from types import ModuleType +from typing import Any +from typing import TYPE_CHECKING + + +_import_structure = { + "allennlp": ["AllenNLPExecutor", "AllenNLPPruningCallback"], + "catalyst": ["CatalystPruningCallback"], + "catboost": ["CatBoostPruningCallback"], + "chainer": ["ChainerPruningExtension"], + "chainermn": ["ChainerMNStudy"], + "fastaiv1": ["FastAIV1PruningCallback"], + "fastaiv2": ["FastAIV2PruningCallback", "FastAIPruningCallback"], + "keras": ["KerasPruningCallback"], + "mxnet": ["MXNetPruningCallback"], + "shap": ["ShapleyImportanceEvaluator"], + "sklearn": ["OptunaSearchCV"], + "skorch": ["SkorchPruningCallback"], + "tensorboard": ["TensorBoardCallback"], + "tensorflow": ["TensorFlowPruningHook"], + "tfkeras": ["TFKerasPruningCallback"], +} + + +if TYPE_CHECKING: + from optuna_integration.allennlp import AllenNLPExecutor + from optuna_integration.allennlp import AllenNLPPruningCallback + from optuna_integration.catalyst import CatalystPruningCallback + from optuna_integration.catboost import CatBoostPruningCallback + from optuna_integration.chainer import ChainerPruningExtension + from optuna_integration.chainermn import 
ChainerMNStudy + from optuna_integration.fastaiv1 import FastAIV1PruningCallback + from optuna_integration.fastaiv2 import FastAIPruningCallback + from optuna_integration.fastaiv2 import FastAIV2PruningCallback + from optuna_integration.keras import KerasPruningCallback + from optuna_integration.mxnet import MXNetPruningCallback + from optuna_integration.shap import ShapleyImportanceEvaluator + from optuna_integration.sklearn import OptunaSearchCV + from optuna_integration.skorch import SkorchPruningCallback + from optuna_integration.tensorboard import TensorBoardCallback + from optuna_integration.tensorflow import TensorFlowPruningHook + from optuna_integration.tfkeras import TFKerasPruningCallback +else: + + class _IntegrationModule(ModuleType): + """Module class that implements `optuna_integration` package. + + This class applies lazy import under `optuna_integration`, where submodules are imported + when they are actually accessed. Otherwise, `import optuna` becomes much slower because it + imports all submodules and their dependencies (e.g., chainer, keras, lightgbm) all at once. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + _modules = set(_import_structure.keys()) + _class_to_module = {} + for key, values in _import_structure.items(): + for value in values: + _class_to_module[value] = key + + def __getattr__(self, name: str) -> Any: + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError("module {} has no attribute {}".format(self.__name__, name)) + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str) -> ModuleType: + import importlib + + try: + return importlib.import_module("." 
+ module_name, self.__name__) + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Optuna's integration modules for third-party libraries have started " + "migrating from Optuna itself to a package called `optuna-integration`. " + "The module you are trying to use has already been migrated to " + "`optuna-integration`. Please install the package by running " + "`pip install optuna-integration`." + ) + + sys.modules[__name__] = _IntegrationModule(__name__) + +__all__ = [ + "AllenNLPExecutor", + "AllenNLPPruningCallback", + "CatalystPruningCallback", + "CatBoostPruningCallback", + "ChainerMNStudy", + "ChainerPruningExtension", + "FastAIPruningCallback", + "FastAIV1PruningCallback", + "FastAIV2PruningCallback", + "KerasPruningCallback", + "MXNetPruningCallback", + "OptunaSearchCV", + "ShapleyImportanceEvaluator", + "SkorchPruningCallback", + "TensorBoardCallback", + "TensorFlowPruningHook", + "TFKerasPruningCallback", +] diff --git a/optuna_integration/shap.py b/optuna_integration/shap.py index 327c97c6..d20d5d13 100644 --- a/optuna_integration/shap.py +++ b/optuna_integration/shap.py @@ -18,9 +18,8 @@ with try_import() as _imports: - from sklearn.ensemble import RandomForestRegressor - from shap import TreeExplainer + from sklearn.ensemble import RandomForestRegressor @experimental_class("3.0.0") diff --git a/optuna_integration/sklearn.py b/optuna_integration/sklearn.py index 5135ffb2..86ab2270 100644 --- a/optuna_integration/sklearn.py +++ b/optuna_integration/sklearn.py @@ -14,7 +14,6 @@ from typing import Union import numpy as np - from optuna import distributions from optuna import logging from optuna import samplers @@ -33,6 +32,7 @@ import pandas as pd import scipy as sp from scipy.sparse import spmatrix + import sklearn from sklearn.base import BaseEstimator from sklearn.base import clone @@ -46,6 +46,7 @@ from sklearn.utils.metaestimators import _safe_split from sklearn.utils.validation import check_is_fitted + if not _imports.is_successful(): 
BaseEstimator = object # NOQA @@ -216,18 +217,34 @@ def __call__(self, trial: Trial) -> float: if self.enable_pruning: scores = self._cross_validate_with_pruning(trial, estimator) else: + sklearn_version = sklearn.__version__.split(".") + sklearn_major_version = int(sklearn_version[0]) + sklearn_minor_version = int(sklearn_version[1]) try: - scores = cross_validate( - estimator, - self.X, - self.y, - cv=self.cv, - error_score=self.error_score, - fit_params=self.fit_params, - groups=self.groups, - return_train_score=self.return_train_score, - scoring=self.scoring, - ) + if sklearn_major_version == 1 and sklearn_minor_version >= 4: + scores = cross_validate( + estimator, + self.X, + self.y, + cv=self.cv, + error_score=self.error_score, + params=self.fit_params, + groups=self.groups, + return_train_score=self.return_train_score, + scoring=self.scoring, + ) + else: + scores = cross_validate( + estimator, + self.X, + self.y, + cv=self.cv, + error_score=self.error_score, + fit_params=self.fit_params, + groups=self.groups, + return_train_score=self.return_train_score, + scoring=self.scoring, + ) except ValueError: n_splits = self.cv.get_n_splits(self.X, self.y, self.groups) fit_time = np.array([np.nan] * n_splits) @@ -410,7 +427,7 @@ class OptunaSearchCV(BaseEstimator): .. note:: ``n_jobs`` allows parallelization using :obj:`threading` and may suffer from `Python's GIL `_. - It is recommended to use :ref:`process-based parallelization` + It is recommended to use `process-based optimization `_ if ``func`` is CPU bound. n_trials: @@ -474,8 +491,8 @@ class OptunaSearchCV(BaseEstimator): .. seealso:: - See the tutorial of :ref:`optuna_callback` for how to use and implement - callback functions. + See the tutorial of `Callback for Study.optimize `_ + for how to use and implement callback functions. Attributes: best_estimator_: @@ -503,6 +520,8 @@ class OptunaSearchCV(BaseEstimator): .. 
testcode:: import optuna + import optuna_integration + from sklearn.datasets import load_iris from sklearn.svm import SVC @@ -510,7 +529,7 @@ class OptunaSearchCV(BaseEstimator): param_distributions = { "C": optuna.distributions.FloatDistribution(1e-10, 1e10, log=True) } - optuna_search = optuna.integration.OptunaSearchCV(clf, param_distributions) + optuna_search = optuna_integration.OptunaSearchCV(clf, param_distributions) X, y = load_iris(return_X_y=True) optuna_search.fit(X, y) y_pred = optuna_search.predict(X) @@ -519,7 +538,7 @@ class OptunaSearchCV(BaseEstimator): By following the scikit-learn convention for scorers, the direction of optimization is ``maximize``. See https://scikit-learn.org/stable/modules/model_evaluation.html. For the minimization problem, please multiply ``-1``. - """ + """ # NOQA: E501 _required_parameters = ["estimator", "param_distributions"] @@ -570,8 +589,8 @@ def classes_(self) -> OneDimArrayLikeType: @property def cv_results_(self) -> dict[str, Any]: - """A dictionary mapping a metric name to a list of - Cross-Validation results of all trials.""" + """A dictionary mapping a metric name to a list of Cross-Validation results of all trials.""" # NOQA: E501 + cv_results_dict_in_list = [trial_.user_attrs for trial_ in self.trials_] if len(cv_results_dict_in_list) == 0: cv_results_list_in_dict = {} diff --git a/pyproject.toml b/pyproject.toml index d31d46e5..0800523c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ checking = [ "typing_extensions>=3.10.0.0", ] document = [ + "pandas", + "scikit-learn>=0.24.2", + "scipy>=1.9.2; python_version>='3.8'", "sphinx", "sphinx_rtd_theme", ] @@ -57,7 +60,10 @@ all = [ "catboost>=0.26,<1.2; sys_platform=='darwin'", "fastai", "mxnet", + "pandas", "shap", + "scikit-learn>=0.24.2", + "scipy>=1.9.2; python_version>='3.8'", "skorch", "tensorboard", "tensorflow", diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index d52bf898..603467ab 100644 --- a/tests/test_sklearn.py +++ 
b/tests/test_sklearn.py @@ -5,8 +5,14 @@ import warnings import numpy as np +from optuna import distributions +from optuna.samplers import BruteForceSampler +from optuna.study import create_study +from optuna.terminator.erroreval import _CROSS_VALIDATION_SCORES_KEY import pytest import scipy as sp + +import optuna_integration as integration from sklearn.datasets import make_blobs from sklearn.datasets import make_regression from sklearn.decomposition import PCA @@ -17,12 +23,6 @@ from sklearn.neighbors import KernelDensity from sklearn.tree import DecisionTreeRegressor -from optuna import distributions -from optuna import integration -from optuna.samplers import BruteForceSampler -from optuna.study import create_study -from optuna.terminator.erroreval import _CROSS_VALIDATION_SCORES_KEY - pytestmark = pytest.mark.integration @@ -414,7 +414,7 @@ def test_callbacks() -> None: @pytest.mark.filterwarnings("ignore::UserWarning") -@patch("optuna.integration.sklearn.cross_validate") +@patch("optuna_integration.sklearn.cross_validate") def test_terminator_cv_score_reporting(mock: MagicMock) -> None: scores = { "fit_time": np.array([2.01, 1.78, 3.22]),