diff --git a/package/samplers/auto_sampler/tests/test_sampler.py b/package/samplers/auto_sampler/tests/test_sampler.py index 039c8fc8..470c47e4 100644 --- a/package/samplers/auto_sampler/tests/test_sampler.py +++ b/package/samplers/auto_sampler/tests/test_sampler.py @@ -6,6 +6,7 @@ from multiprocessing.managers import DictProxy import os from typing import Any +from unittest.mock import Mock import warnings from _pytest.fixtures import SubRequest @@ -704,11 +705,17 @@ def test_cache_is_invalidated( n_jobs: int, relative_sampler_class: Callable[[], BaseSampler] ) -> None: sampler = relative_sampler_class() + original_before_trial = sampler.before_trial + + def mock_before_trial(study: Study, trial: FrozenTrial) -> None: + assert study._thread_local.cached_all_trials is None + original_before_trial(study, trial) + + sampler.before_trial = Mock(side_effect=mock_before_trial) study = optuna.study.create_study(sampler=sampler) def objective(trial: Trial) -> float: assert trial._relative_params is None - assert study._thread_local.cached_all_trials is None trial.suggest_float("x", -10, 10) trial.suggest_float("y", -10, 10)