Skip to content

Commit

Permalink
Merge pull request optuna#1 from toshihikoyanase/prototype-auto-sampl…
Browse files Browse the repository at this point in the history
…er-with-fallback

Support multi-objective and constrained optimization.
  • Loading branch information
nabenabe0928 authored Oct 18, 2024
2 parents 6df7a95 + f71caeb commit fe3da2e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 8 deletions.
4 changes: 3 additions & 1 deletion package/samplers/auto_sampler/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ This sampler currently accepts only `seed`.

## Installation

This sampler requires optional dependencies of Optuna.

```shell
$ pip install scipy torch cmaes
$ pip install "optuna[optional]"
```

## Example
Expand Down
50 changes: 43 additions & 7 deletions package/samplers/auto_sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from collections.abc import Sequence
from typing import Any
from typing import Callable
from typing import TYPE_CHECKING

from optuna.distributions import CategoricalDistribution
from optuna.samplers import BaseSampler
from optuna.samplers import CmaEsSampler
from optuna.samplers import GPSampler
from optuna.samplers import NSGAIISampler
from optuna.samplers import RandomSampler
from optuna.samplers import TPESampler
from optuna.samplers._lazy_random_state import LazyRandomState
Expand All @@ -31,30 +33,50 @@ class AutoSampler(BaseSampler):
.. testcode::
import optuna
from optuna.samplers import AutoSampler
import optunahub
def objective(trial):
x = trial.suggest_float("x", -5, 5)
return x**2
study = optuna.create_study(sampler=AutoSampler())
study.optimize(objective, n_trials=10)
module = optunahub.load_module("samplers/auto_sampler")
study = optuna.create_study(sampler=module.AutoSampler())
study.optimize(objective, n_trials=300)
.. note::
This sampler might require ``scipy``, ``torch``, and ``cmaes``.
You can install these dependencies with ``pip install scipy torch cmaes``.
This sampler requires optional dependencies of Optuna.
You can install them with ``pip install "optuna[optional]"``.
Alternatively, you can install them with ``pip install -r https://hub.optuna.org/samplers/auto_sampler/requirements.txt``.
Args:
seed: Seed for random number generator.
constraints_func:
An optional function that computes the objective constraints. It must take a
:class:`~optuna.trial.FrozenTrial` and return the constraints. The return value must
be a sequence of :obj:`float` s. A value strictly larger than 0 means that a
constraints is violated. A value equal to or smaller than 0 is considered feasible.
If ``constraints_func`` returns more than one value for a trial, that trial is
considered feasible if and only if all values are equal to 0 or smaller.
The ``constraints_func`` will be evaluated after each successful trial.
The function won't be called when trials fail or they are pruned, but this behavior is
subject to change in the future releases.
.. note::
If you enable this feature, Optuna's default sampler will be selected automatically.
"""

def __init__(self, seed: int | None = None) -> None:
def __init__(
self,
seed: int | None = None,
constraints_func: Callable[[FrozenTrial], Sequence[float]] | None = None,
) -> None:
self._rng = LazyRandomState(seed)
seed_for_random_sampler = self._rng.rng.randint(1 << 32)
self._sampler: BaseSampler = RandomSampler(seed=seed_for_random_sampler)
self._constraints_func = constraints_func

def reseed_rng(self) -> None:
self._rng.rng.seed()
Expand All @@ -75,10 +97,24 @@ def _include_conditional_param(self, study: Study) -> bool:
def _determine_sampler(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> None:
if len(study.directions) > 1:
# Fallback to the default sampler if the study has multiple objectives.
if isinstance(self._sampler, NSGAIISampler):
return
# TODO(toshihikoyanase): add warning message about fallback.
self._sampler = NSGAIISampler(constraints_func=self._constraints_func)
return

if isinstance(self._sampler, TPESampler):
return

seed = self._rng.rng.randint(1 << 32)
if self._constraints_func is not None:
# Fallback to the default sampler if the study has constraints.
# TODO(toshihikoyanase): add warning message about fallback.
self._sampler = TPESampler(seed=seed, constraines_func=self._constraints_func)
return

if any(
isinstance(d, CategoricalDistribution) for d in search_space.values()
) or self._include_conditional_param(study):
Expand Down

0 comments on commit fe3da2e

Please sign in to comment.