From d37f6c9e01583f45c81150f30feb57c8f7c1bdc1 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Wed, 30 Oct 2024 05:44:05 +0100 Subject: [PATCH] Apply c-bata's comments --- package/samplers/auto_sampler/README.md | 2 +- package/samplers/auto_sampler/_sampler.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/package/samplers/auto_sampler/README.md b/package/samplers/auto_sampler/README.md index 1d726e74..96bddde9 100644 --- a/package/samplers/auto_sampler/README.md +++ b/package/samplers/auto_sampler/README.md @@ -54,5 +54,5 @@ pip install pytest ``` ```python -python -m pytest package/samplers/auto_sampler/tests/test_auto_sampler.py +pytest package/samplers/auto_sampler/tests/ ``` diff --git a/package/samplers/auto_sampler/_sampler.py b/package/samplers/auto_sampler/_sampler.py index daa18160..85aa971d 100644 --- a/package/samplers/auto_sampler/_sampler.py +++ b/package/samplers/auto_sampler/_sampler.py @@ -32,8 +32,8 @@ _logger = get_logger(f"optuna.{__name__}") -class _ThreadLocalSampler(threading.local): - _sampler: BaseSampler | None = None +class ThreadLocalSampler(threading.local): + sampler: BaseSampler | None = None class AutoSampler(BaseSampler): @@ -87,23 +87,23 @@ def __init__( constraints_func: Callable[[FrozenTrial], Sequence[float]] | None = None, ) -> None: self._rng = LazyRandomState(seed) - self._thread_local_sampler = _ThreadLocalSampler() + self._thread_local_sampler = ThreadLocalSampler() self._constraints_func = constraints_func @property def _sampler(self) -> BaseSampler: - if self._thread_local_sampler._sampler is None: + if self._thread_local_sampler.sampler is None: # NOTE(nabenabe): Do not do this process in the __init__ method because the # substitution at the init does not update attributes in self._thread_local_sampler # in each thread. seed_for_random_sampler = self._rng.rng.randint(_MAXINT32) self._sampler = RandomSampler(seed=seed_for_random_sampler) - return self._thread_local_sampler._sampler + return self._thread_local_sampler.sampler @_sampler.setter def _sampler(self, sampler: BaseSampler) -> None: - self._thread_local_sampler._sampler = sampler + self._thread_local_sampler.sampler = sampler def reseed_rng(self) -> None: self._rng.rng.seed()