Skip to content

Commit

Permalink
Make sampler thread local
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 25, 2024
1 parent 5b9c9a7 commit 7e75683
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion package/samplers/auto_sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Callable
from collections.abc import Sequence
import threading
from typing import Any
from typing import TYPE_CHECKING

Expand All @@ -28,6 +29,12 @@
SAMPLER_KEY = "auto:sampler"


class _ThreadLocalSampler(threading.local):
def __init__(self, sampler: BaseSampler) -> None:
self._sampler = sampler
super().__init__()


class AutoSampler(BaseSampler):
_N_COMPLETE_TRIALS_FOR_CMAES = 250

Expand Down Expand Up @@ -79,9 +86,18 @@ def __init__(
) -> None:
self._rng = LazyRandomState(seed)
seed_for_random_sampler = self._rng.rng.randint(MAXINT32)
self._sampler: BaseSampler = RandomSampler(seed=seed_for_random_sampler)
sampler: BaseSampler = RandomSampler(seed=seed_for_random_sampler)
self._thread_local_sampler = _ThreadLocalSampler(sampler)
self._constraints_func = constraints_func

@property
def _sampler(self) -> BaseSampler:
return self._thread_local_sampler._sampler

@_sampler.setter
def _sampler(self, sampler: BaseSampler) -> None:
self._thread_local_sampler._sampler = sampler

def reseed_rng(self) -> None:
self._rng.rng.seed()
self._sampler.reseed_rng()
Expand Down

0 comments on commit 7e75683

Please sign in to comment.