diff --git a/package/samplers/auto_sampler/sampler.py b/package/samplers/auto_sampler/sampler.py index cab1d7d2..29159b2a 100644 --- a/package/samplers/auto_sampler/sampler.py +++ b/package/samplers/auto_sampler/sampler.py @@ -96,28 +96,23 @@ def _include_conditional_param(self, study: Study) -> bool: return False param_key = set(trials[0].params) - for t in trials: - if param_key != set(t.params): - return True - - return False + return any(param_key != set(t.params) for t in trials) def _determine_multi_objective_sampler( self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution] - ) -> None: + ) -> BaseSampler: if isinstance(self._sampler, (NSGAIISampler, NSGAIIISampler)): - return + return self._sampler seed = self._rng.rng.randint(MAXINT32) if not isinstance(self._sampler, TPESampler): - self._sampler = TPESampler( + return TPESampler( seed=seed, multivariate=True, warn_independent_sampling=False, constraints_func=self._constraints_func, constant_liar=True, ) - return complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,)) complete_trials.sort(key=lambda trial: trial.datetime_complete) @@ -128,13 +123,15 @@ def _determine_multi_objective_sampler( else NSGAIIISampler ) # Use NSGA-II/III if len(complete_trials) <= _N_COMPLETE_TRIALS_FOR_NSGA. - self._sampler = nsga_sampler_cls(constraints_func=self._constraints_func, seed=seed) + return nsga_sampler_cls(constraints_func=self._constraints_func, seed=seed) + + return self._sampler def _determine_single_objective_sampler( self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution] - ) -> None: + ) -> BaseSampler | None: if isinstance(self._sampler, TPESampler): - return + return self._sampler seed = self._rng.rng.randint(MAXINT32) if ( @@ -145,14 +142,13 @@ def _determine_single_objective_sampler( # NOTE(nabenabe): The statement above is always true for Trial#1. # Use ``TPESampler`` if search space includes conditional or categorical parameters. # TBD: group=True? - self._sampler = TPESampler( + return TPESampler( seed=seed, multivariate=True, warn_independent_sampling=False, constraints_func=self._constraints_func, constant_liar=True, ) - return complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,)) complete_trials.sort(key=lambda trial: trial.datetime_complete) @@ -160,27 +156,27 @@ def _determine_single_objective_sampler( # Use ``GPSampler`` if search space is numerical and # len(complete_trials) < _N_COMPLETE_TRIALS_FOR_CMAES. if not isinstance(self._sampler, GPSampler): - self._sampler = GPSampler(seed=seed) - return - - if not isinstance(self._sampler, CmaEsSampler): + return GPSampler(seed=seed) + elif not isinstance(self._sampler, CmaEsSampler): # Use ``CmaEsSampler`` if search space is numerical and # len(complete_trials) > _N_COMPLETE_TRIALS_FOR_CMAES. # Warm start CMA-ES with the first _N_COMPLETE_TRIALS_FOR_CMAES complete trials. warm_start_trials = complete_trials[: self._N_COMPLETE_TRIALS_FOR_CMAES] # NOTE(nabenabe): ``CmaEsSampler`` internally falls back to ``RandomSampler`` for # 1D problems. - self._sampler = CmaEsSampler( + return CmaEsSampler( seed=seed, source_trials=warm_start_trials, warn_independent_sampling=False ) + return self._sampler + def _determine_sampler( self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution] ) -> None: if len(study.directions) == 1: - self._determine_single_objective_sampler(study, trial, search_space) + self._sampler = self._determine_single_objective_sampler(study, trial, search_space) else: - self._determine_multi_objective_sampler(study, trial, search_space) + self._sampler = self._determine_multi_objective_sampler(study, trial, search_space) def infer_relative_search_space( self, study: Study, trial: FrozenTrial