Skip to content

Commit

Permalink
Apply c-bata's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 29, 2024
1 parent 05048c3 commit 5304f9d
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions package/samplers/auto_sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from optuna.trial import FrozenTrial


MAXINT32 = (1 << 31) - 1
SAMPLER_KEY = "auto:sampler"
_MAXINT32 = (1 << 31) - 1
_SAMPLER_KEY = "auto:sampler"
_logger = get_logger(f"optuna.{__name__}")


Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
constraints_func: Callable[[FrozenTrial], Sequence[float]] | None = None,
) -> None:
self._rng = LazyRandomState(seed)
seed_for_random_sampler = self._rng.rng.randint(MAXINT32)
seed_for_random_sampler = self._rng.rng.randint(_MAXINT32)
sampler: BaseSampler = RandomSampler(seed=seed_for_random_sampler)
self._thread_local_sampler = _ThreadLocalSampler(sampler)
self._constraints_func = constraints_func
Expand All @@ -105,7 +105,9 @@ def reseed_rng(self) -> None:
self._sampler.reseed_rng()

def _include_conditional_param(self, study: Study) -> bool:
trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE, TrialState.PRUNED))
trials = study._get_trials(
deepcopy=False, states=(TrialState.COMPLETE, TrialState.PRUNED), use_cache=True
)
if len(trials) == 0:
return False

Expand All @@ -118,7 +120,7 @@ def _determine_multi_objective_sampler(
if isinstance(self._sampler, NSGAIISampler):
return self._sampler

seed = self._rng.rng.randint(MAXINT32)
seed = self._rng.rng.randint(_MAXINT32)
return NSGAIISampler(constraints_func=self._constraints_func, seed=seed)

def _determine_single_objective_sampler(
Expand All @@ -127,7 +129,7 @@ def _determine_single_objective_sampler(
if isinstance(self._sampler, TPESampler):
return self._sampler

seed = self._rng.rng.randint(MAXINT32)
seed = self._rng.rng.randint(_MAXINT32)
if (
self._constraints_func is not None
or any(isinstance(d, CategoricalDistribution) for d in search_space.values())
Expand All @@ -144,7 +146,9 @@ def _determine_single_objective_sampler(
constant_liar=True,
)

complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
complete_trials = study._get_trials(
deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True
)
complete_trials.sort(key=lambda trial: trial.datetime_complete)
if len(complete_trials) < self._N_COMPLETE_TRIALS_FOR_CMAES:
# Use ``GPSampler`` if search space is numerical and
Expand Down Expand Up @@ -195,13 +199,16 @@ def before_trial(self, study: Study, trial: FrozenTrial) -> None:
# NOTE(nabenabe): Sampler must be updated in this method. If, for example, it is updated in
# infer_relative_search_space, the sampler for before_trial and that for sample_relative,
# after_trial might be different, meaning that the sampling routine could be incompatible.
if len(study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,))) != 0:
if (
len(study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True))
!= 0
):
search_space = IntersectionSearchSpace().calculate(study)
self._sampler = self._determine_sampler(study, trial, search_space)

sampler_name = self._sampler.__class__.__name__
_logger.debug(f"Sample trial#{trial.number} with {sampler_name}.")
study._storage.set_trial_system_attr(trial._trial_id, SAMPLER_KEY, sampler_name)
study._storage.set_trial_system_attr(trial._trial_id, _SAMPLER_KEY, sampler_name)
self._sampler.before_trial(study, trial)

def after_trial(
Expand Down

0 comments on commit 5304f9d

Please sign in to comment.