diff --git a/package/samplers/auto_sampler/sampler.py b/package/samplers/auto_sampler/sampler.py index d4c15539..f2bf676e 100644 --- a/package/samplers/auto_sampler/sampler.py +++ b/package/samplers/auto_sampler/sampler.py @@ -211,8 +211,7 @@ def sample_independent( return self._sampler.sample_independent(study, trial, param_name, param_distribution) def before_trial(self, study: Study, trial: FrozenTrial) -> None: - # NOTE(nabenabe): Use the states used in IntersectionSearchSpace().calculate. - states_of_interest = [TrialState.COMPLETE, TrialState.WAITING, TrialState.RUNNING] + states_of_interest = [TrialState.COMPLETE, TrialState.WAITING] if len(study._get_trials(deepcopy=False, states=states_of_interest)) != 0: search_space = IntersectionSearchSpace().calculate(study) self._sampler = self._determine_sampler(study, trial, search_space) diff --git a/package/samplers/auto_sampler/tests/test_auto_sampler.py b/package/samplers/auto_sampler/tests/test_auto_sampler.py index b2d45620..901526a4 100644 --- a/package/samplers/auto_sampler/tests/test_auto_sampler.py +++ b/package/samplers/auto_sampler/tests/test_auto_sampler.py @@ -144,3 +144,6 @@ def test_choose_tpe_with_conditional_params() -> None: assert ["RandomSampler"] + ["GPSampler"] * 15 + ["TPESampler"] * ( n_trials - 16 ) == sampler_names + + +# TODO: Add a test with enqueue_trial.