Skip to content

Commit

Permalink
Apply the feedback from the mob review
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 31, 2024
1 parent a5bc25e commit ea89dd1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
27 changes: 17 additions & 10 deletions package/samplers/auto_sampler/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,23 @@ def _determine_single_objective_sampler(
# len(complete_trials) < _N_COMPLETE_TRIALS_FOR_CMAES.
if not isinstance(self._sampler, GPSampler):
return GPSampler(seed=seed)
elif not isinstance(self._sampler, CmaEsSampler):
# Use ``CmaEsSampler`` if search space is numerical and
# len(complete_trials) > _N_COMPLETE_TRIALS_FOR_CMAES.
# Warm start CMA-ES with the first _N_COMPLETE_TRIALS_FOR_CMAES complete trials.
complete_trials.sort(key=lambda trial: trial.datetime_complete)
warm_start_trials = complete_trials[: self._N_COMPLETE_TRIALS_FOR_CMAES]
# NOTE(nabenabe): ``CmaEsSampler`` internally falls back to ``RandomSampler`` for
# 1D problems.
return CmaEsSampler(
seed=seed, source_trials=warm_start_trials, warn_independent_sampling=True
elif len(search_space) > 1:
if not isinstance(self._sampler, CmaEsSampler):
# Use ``CmaEsSampler`` if search space is numerical and
# len(complete_trials) > _N_COMPLETE_TRIALS_FOR_CMAES.
# Warm start CMA-ES with the first _N_COMPLETE_TRIALS_FOR_CMAES complete trials.
complete_trials.sort(key=lambda trial: trial.datetime_complete)
warm_start_trials = complete_trials[: self._N_COMPLETE_TRIALS_FOR_CMAES]
return CmaEsSampler(
seed=seed, source_trials=warm_start_trials, warn_independent_sampling=True
)
else:
return TPESampler(
seed=seed,
multivariate=True,
warn_independent_sampling=False,
constraints_func=self._constraints_func,
constant_liar=True,
)

return self._sampler # No update happens to self._sampler.
Expand Down
20 changes: 20 additions & 0 deletions package/samplers/auto_sampler/tests/test_auto_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
parametrize_constraints = pytest.mark.parametrize("use_constraint", [True, False])


def objective_1d(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -5, 5)
return x**2


def objective(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -5, 5)
y = trial.suggest_int("y", -5, 5)
Expand Down Expand Up @@ -105,6 +110,21 @@ def test_choose_cmaes() -> None:
] * n_trials_of_cmaes == sampler_names


def test_choose_tpe_for_1d() -> None:
# This test must be performed with a numerical objective function.
# For 1d problems, TPESampler will be chosen instead of CmaEsSampler.
n_trials_of_tpe = 100
n_trials_before_tpe = 20
auto_sampler = AutoSampler()
auto_sampler._N_COMPLETE_TRIALS_FOR_CMAES = n_trials_before_tpe
study = optuna.create_study(sampler=auto_sampler)
study.optimize(objective_1d, n_trials=n_trials_of_tpe + n_trials_before_tpe)
sampler_names = _get_used_sampler_names(study)
assert ["RandomSampler"] + ["GPSampler"] * (n_trials_before_tpe - 1) + [
"TPESampler"
] * n_trials_of_tpe == sampler_names


def test_choose_tpe_in_single_with_constraints() -> None:
n_trials = 30
auto_sampler = AutoSampler(constraints_func=constraints_func)
Expand Down

0 comments on commit ea89dd1

Please sign in to comment.