Skip to content

Commit

Permalink
Fix local thread and its tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 29, 2024
1 parent c2be135 commit 9a4d0d9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
11 changes: 7 additions & 4 deletions package/samplers/auto_sampler/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,18 @@ def __init__(
constraints_func: Callable[[FrozenTrial], Sequence[float]] | None = None,
) -> None:
self._rng = LazyRandomState(seed)
seed_for_random_sampler = self._rng.rng.randint(_MAXINT32)
sampler: BaseSampler = RandomSampler(seed=seed_for_random_sampler)
self._thread_local_sampler = _ThreadLocalSampler()
self._thread_local_sampler._sampler = sampler
self._constraints_func = constraints_func

@property
def _sampler(self) -> BaseSampler:
assert self._sampler is not None
if self._thread_local_sampler._sampler is None:
# NOTE(nabenabe): Do not do this process in the __init__ method because the
# substitution at the init does not update attributes in self._thread_local_sampler
# in each thread.
seed_for_random_sampler = self._rng.rng.randint(_MAXINT32)
self._sampler = RandomSampler(seed=seed_for_random_sampler)

return self._thread_local_sampler._sampler

@_sampler.setter
Expand Down
12 changes: 12 additions & 0 deletions package/samplers/auto_sampler/tests/test_auto_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,15 @@ def test_choose_tpe_with_conditional_params() -> None:
assert ["RandomSampler"] + ["GPSampler"] * 15 + ["TPESampler"] * (
n_trials - 16
) == sampler_names


def test_multi_thread() -> None:
n_trials = 30
auto_sampler = AutoSampler()
auto_sampler._N_COMPLETE_TRIALS_FOR_CMAES = 10
study = optuna.create_study(sampler=auto_sampler)
study.optimize(objective, n_trials=n_trials)
sampler_names = _get_used_sampler_names(study)
assert "RandomSampler" in sampler_names
assert "GPSampler" in sampler_names
assert "CmaEsSampler" in sampler_names

0 comments on commit 9a4d0d9

Please sign in to comment.