Skip to content

Commit

Permalink
Remove TPESampler for MO
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 25, 2024
1 parent 8ac08be commit 5b9c9a7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 57 deletions.
47 changes: 6 additions & 41 deletions package/samplers/auto_sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
from optuna.samplers import BaseSampler
from optuna.samplers import CmaEsSampler
from optuna.samplers import GPSampler
from optuna.samplers import NSGAIIISampler
from optuna.samplers import NSGAIISampler
from optuna.samplers import RandomSampler
from optuna.samplers import TPESampler
from optuna.samplers._base import _process_constraints_after_trial
from optuna.samplers._lazy_random_state import LazyRandomState
from optuna.samplers._nsgaiii._sampler import _GENERATION_KEY as NSGA3_GENERATION_KEY
from optuna.samplers.nsgaii._sampler import _GENERATION_KEY as NSGA2_GENERATION_KEY
from optuna.search_space import IntersectionSearchSpace
from optuna.trial import TrialState

Expand All @@ -28,13 +25,11 @@


MAXINT32 = (1 << 31) - 1
THRESHOLD_OF_MANY_OBJECTIVES = 4
SAMPLER_KEY = "auto:sampler"


class AutoSampler(BaseSampler):
_N_COMPLETE_TRIALS_FOR_CMAES = 250
_N_COMPLETE_TRIALS_FOR_NSGA = 1000

"""Sampler automatically choosing an appropriate sampler based on search space.
Expand Down Expand Up @@ -102,31 +97,11 @@ def _include_conditional_param(self, study: Study) -> bool:
def _determine_multi_objective_sampler(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> BaseSampler:
if isinstance(self._sampler, (NSGAIISampler, NSGAIIISampler)):
if isinstance(self._sampler, NSGAIISampler):
return self._sampler

seed = self._rng.rng.randint(MAXINT32)
if not isinstance(self._sampler, TPESampler):
return TPESampler(
seed=seed,
multivariate=True,
warn_independent_sampling=False,
constraints_func=self._constraints_func,
constant_liar=True,
)

complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
complete_trials.sort(key=lambda trial: trial.datetime_complete)
if len(complete_trials) >= self._N_COMPLETE_TRIALS_FOR_NSGA:
nsga_sampler_cls = (
NSGAIISampler
if len(study.directions) < THRESHOLD_OF_MANY_OBJECTIVES
else NSGAIIISampler
)
# Use NSGA-II/III if len(complete_trials) <= _N_COMPLETE_TRIALS_FOR_NSGA.
return nsga_sampler_cls(constraints_func=self._constraints_func, seed=seed)

return self._sampler # No update happens to self._sampler.
return NSGAIISampler(constraints_func=self._constraints_func, seed=seed)

def _determine_single_objective_sampler(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
Expand Down Expand Up @@ -187,18 +162,6 @@ def infer_relative_search_space(
def sample_relative(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> dict[str, Any]:
n_objectives = len(study.directions)
if n_objectives > 1 and isinstance(self._sampler, TPESampler):
# NOTE(nabenabe): Set generation 0 so that NSGA-II/III can use the trial information
# obtained during the optimization using TPESampler.
# NOTE(nabenabe): Use NSGA-III for many objective problems.
_GENERATION_KEY = (
NSGA2_GENERATION_KEY
if n_objectives < THRESHOLD_OF_MANY_OBJECTIVES
else NSGA3_GENERATION_KEY
)
study._storage.set_trial_system_attr(trial._trial_id, _GENERATION_KEY, 0)

return self._sampler.sample_relative(study, trial, search_space)

def sample_independent(
Expand All @@ -211,8 +174,10 @@ def sample_independent(
return self._sampler.sample_independent(study, trial, param_name, param_distribution)

def before_trial(self, study: Study, trial: FrozenTrial) -> None:
states_of_interest = [TrialState.COMPLETE, TrialState.WAITING]
if len(study._get_trials(deepcopy=False, states=states_of_interest)) != 0:
# NOTE(nabenabe): Sampler must be updated in this method. If, for example, it is updated in
# infer_relative_search_space, the sampler for before_trial and that for sample_relative,
# after_trial might be different, meaning that the sampling routine could be incompatible.
if len(study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,))) != 0:
search_space = IntersectionSearchSpace().calculate(study)
self._sampler = self._determine_sampler(study, trial, search_space)

Expand Down
24 changes: 8 additions & 16 deletions package/samplers/auto_sampler/tests/test_auto_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,33 +70,25 @@ def _check_constraints_of_all_trials(study: optuna.Study) -> None:


@parametrize_constraints
def test_choose_nsga3(use_constraint: bool) -> None:
n_trials_of_nsga = 100
n_trials_before_nsga = 100
def test_choose_for_many_objective(use_constraint: bool) -> None:
n_trials = 200
auto_sampler = AutoSampler(constraints_func=constraints_func if use_constraint else None)
auto_sampler._N_COMPLETE_TRIALS_FOR_NSGA = n_trials_before_nsga
study = optuna.create_study(sampler=auto_sampler, directions=["minimize"] * 4)
study.optimize(many_objective, n_trials=n_trials_before_nsga + n_trials_of_nsga)
study.optimize(many_objective, n_trials=n_trials)
sampler_names = _get_used_sampler_names(study)
assert ["RandomSampler"] + ["TPESampler"] * (n_trials_before_nsga - 1) + [
"NSGAIIISampler"
] * n_trials_of_nsga == sampler_names
assert ["RandomSampler"] + ["NSGAIISampler"] * (n_trials - 1) == sampler_names
if use_constraint:
_check_constraints_of_all_trials(study)


@parametrize_constraints
def test_choose_nsga2(use_constraint: bool) -> None:
n_trials_of_nsga = 100
n_trials_before_nsga = 100
def test_choose_for_multi_objective(use_constraint: bool) -> None:
n_trials = 200
auto_sampler = AutoSampler(constraints_func=constraints_func if use_constraint else None)
auto_sampler._N_COMPLETE_TRIALS_FOR_NSGA = n_trials_before_nsga
study = optuna.create_study(sampler=auto_sampler, directions=["minimize"] * 2)
study.optimize(multi_objective, n_trials=n_trials_before_nsga + n_trials_of_nsga)
study.optimize(multi_objective, n_trials=n_trials)
sampler_names = _get_used_sampler_names(study)
assert ["RandomSampler"] + ["TPESampler"] * (n_trials_before_nsga - 1) + [
"NSGAIISampler"
] * n_trials_of_nsga == sampler_names
assert ["RandomSampler"] + ["NSGAIISampler"] * (n_trials - 1) == sampler_names
if use_constraint:
_check_constraints_of_all_trials(study)

Expand Down

0 comments on commit 5b9c9a7

Please sign in to comment.