Skip to content

Commit

Permalink
Enable cache
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 29, 2024
1 parent 5304f9d commit 3d4dd5e
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion package/samplers/auto_sampler/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from multiprocessing.managers import DictProxy
import os
from typing import Any
from unittest.mock import Mock
import warnings

from _pytest.fixtures import SubRequest
Expand Down Expand Up @@ -704,11 +705,17 @@ def test_cache_is_invalidated(
n_jobs: int, relative_sampler_class: Callable[[], BaseSampler]
) -> None:
sampler = relative_sampler_class()
original_before_trial = sampler.before_trial

def mock_before_trial(study: Study, trial: FrozenTrial) -> None:
assert study._thread_local.cached_all_trials is None
original_before_trial(study, trial)

sampler.before_trial = Mock(side_effect=mock_before_trial)
study = optuna.study.create_study(sampler=sampler)

def objective(trial: Trial) -> float:
assert trial._relative_params is None
assert study._thread_local.cached_all_trials is None

trial.suggest_float("x", -10, 10)
trial.suggest_float("y", -10, 10)
Expand Down

0 comments on commit 3d4dd5e

Please sign in to comment.