diff --git a/package/samplers/auto_sampler/sampler.py b/package/samplers/auto_sampler/sampler.py index b52a4eb8..c0ab815e 100644 --- a/package/samplers/auto_sampler/sampler.py +++ b/package/samplers/auto_sampler/sampler.py @@ -26,8 +26,8 @@ from optuna.trial import FrozenTrial -MAXINT32 = (1 << 31) - 1 -SAMPLER_KEY = "auto:sampler" +_MAXINT32 = (1 << 31) - 1 +_SAMPLER_KEY = "auto:sampler" _logger = get_logger(f"optuna.{__name__}") @@ -87,7 +87,7 @@ def __init__( constraints_func: Callable[[FrozenTrial], Sequence[float]] | None = None, ) -> None: self._rng = LazyRandomState(seed) - seed_for_random_sampler = self._rng.rng.randint(MAXINT32) + seed_for_random_sampler = self._rng.rng.randint(_MAXINT32) sampler: BaseSampler = RandomSampler(seed=seed_for_random_sampler) self._thread_local_sampler = _ThreadLocalSampler(sampler) self._constraints_func = constraints_func @@ -105,7 +105,9 @@ def reseed_rng(self) -> None: self._sampler.reseed_rng() def _include_conditional_param(self, study: Study) -> bool: - trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE, TrialState.PRUNED)) + trials = study._get_trials( + deepcopy=False, states=(TrialState.COMPLETE, TrialState.PRUNED), use_cache=True + ) if len(trials) == 0: return False @@ -118,7 +120,7 @@ def _determine_multi_objective_sampler( if isinstance(self._sampler, NSGAIISampler): return self._sampler - seed = self._rng.rng.randint(MAXINT32) + seed = self._rng.rng.randint(_MAXINT32) return NSGAIISampler(constraints_func=self._constraints_func, seed=seed) def _determine_single_objective_sampler( @@ -127,7 +129,7 @@ def _determine_single_objective_sampler( if isinstance(self._sampler, TPESampler): return self._sampler - seed = self._rng.rng.randint(MAXINT32) + seed = self._rng.rng.randint(_MAXINT32) if ( self._constraints_func is not None or any(isinstance(d, CategoricalDistribution) for d in search_space.values()) @@ -144,7 +146,9 @@ def _determine_single_objective_sampler( constant_liar=True, ) - complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,)) + complete_trials = study._get_trials( + deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True + ) complete_trials.sort(key=lambda trial: trial.datetime_complete) if len(complete_trials) < self._N_COMPLETE_TRIALS_FOR_CMAES: # Use ``GPSampler`` if search space is numerical and @@ -195,13 +199,16 @@ def before_trial(self, study: Study, trial: FrozenTrial) -> None: # NOTE(nabenabe): Sampler must be updated in this method. If, for example, it is updated in # infer_relative_search_space, the sampler for before_trial and that for sample_relative, # after_trial might be different, meaning that the sampling routine could be incompatible. - if len(study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,))) != 0: + if ( + len(study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)) + != 0 + ): search_space = IntersectionSearchSpace().calculate(study) self._sampler = self._determine_sampler(study, trial, search_space) sampler_name = self._sampler.__class__.__name__ _logger.debug(f"Sample trial#{trial.number} with {sampler_name}.") - study._storage.set_trial_system_attr(trial._trial_id, SAMPLER_KEY, sampler_name) + study._storage.set_trial_system_attr(trial._trial_id, _SAMPLER_KEY, sampler_name) self._sampler.before_trial(study, trial) def after_trial(