Skip to content

Commit

Permalink
Address mizuno's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 24, 2024
1 parent d06a851 commit 72d661f
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions package/samplers/auto_sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,28 +96,23 @@ def _include_conditional_param(self, study: Study) -> bool:
return False

param_key = set(trials[0].params)
for t in trials:
if param_key != set(t.params):
return True

return False
return any(param_key != set(t.params) for t in trials)

def _determine_multi_objective_sampler(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> None:
) -> BaseSampler:
if isinstance(self._sampler, (NSGAIISampler, NSGAIIISampler)):
return
return self._sampler

seed = self._rng.rng.randint(MAXINT32)
if not isinstance(self._sampler, TPESampler):
self._sampler = TPESampler(
return TPESampler(
seed=seed,
multivariate=True,
warn_independent_sampling=False,
constraints_func=self._constraints_func,
constant_liar=True,
)
return

complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
complete_trials.sort(key=lambda trial: trial.datetime_complete)
Expand All @@ -128,13 +123,15 @@ def _determine_multi_objective_sampler(
else NSGAIIISampler
)
# Use NSGA-II/III if len(complete_trials) <= _N_COMPLETE_TRIALS_FOR_NSGA.
self._sampler = nsga_sampler_cls(constraints_func=self._constraints_func, seed=seed)
return nsga_sampler_cls(constraints_func=self._constraints_func, seed=seed)

return self._sampler

def _determine_single_objective_sampler(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> None:
) -> BaseSampler | None:
if isinstance(self._sampler, TPESampler):
return
return self._sampler

seed = self._rng.rng.randint(MAXINT32)
if (
Expand All @@ -145,42 +142,41 @@ def _determine_single_objective_sampler(
# NOTE(nabenabe): The statement above is always true for Trial#1.
# Use ``TPESampler`` if search space includes conditional or categorical parameters.
# TBD: group=True?
self._sampler = TPESampler(
return TPESampler(
seed=seed,
multivariate=True,
warn_independent_sampling=False,
constraints_func=self._constraints_func,
constant_liar=True,
)
return

complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
complete_trials.sort(key=lambda trial: trial.datetime_complete)
if len(complete_trials) < self._N_COMPLETE_TRIALS_FOR_CMAES:
# Use ``GPSampler`` if search space is numerical and
# len(complete_trials) < _N_COMPLETE_TRIALS_FOR_CMAES.
if not isinstance(self._sampler, GPSampler):
self._sampler = GPSampler(seed=seed)
return

if not isinstance(self._sampler, CmaEsSampler):
return GPSampler(seed=seed)
elif not isinstance(self._sampler, CmaEsSampler):
# Use ``CmaEsSampler`` if search space is numerical and
# len(complete_trials) > _N_COMPLETE_TRIALS_FOR_CMAES.
# Warm start CMA-ES with the first _N_COMPLETE_TRIALS_FOR_CMAES complete trials.
warm_start_trials = complete_trials[: self._N_COMPLETE_TRIALS_FOR_CMAES]
# NOTE(nabenabe): ``CmaEsSampler`` internally falls back to ``RandomSampler`` for
# 1D problems.
self._sampler = CmaEsSampler(
return CmaEsSampler(
seed=seed, source_trials=warm_start_trials, warn_independent_sampling=False
)

return self._sampler

def _determine_sampler(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> None:
if len(study.directions) == 1:
self._determine_single_objective_sampler(study, trial, search_space)
self._sampler = self._determine_single_objective_sampler(study, trial, search_space)
else:
self._determine_multi_objective_sampler(study, trial, search_space)
self._sampler = self._determine_multi_objective_sampler(study, trial, search_space)

def infer_relative_search_space(
self, study: Study, trial: FrozenTrial
Expand Down

0 comments on commit 72d661f

Please sign in to comment.