Skip to content

Commit

Permalink
Apply ozaki's comment
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 31, 2024
1 parent 13e63a6 commit 396b703
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions package/samplers/auto_sampler/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from multiprocessing.managers import DictProxy
import os
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
import warnings

from _pytest.fixtures import SubRequest
Expand Down Expand Up @@ -677,15 +677,15 @@ def mock_before_trial(study: Study, trial: FrozenTrial) -> None:
assert study._thread_local.cached_all_trials is None
original_before_trial(study, trial)

sampler.before_trial = Mock(side_effect=mock_before_trial)
study = optuna.study.create_study(sampler=sampler)
with patch.object(sampler, "before_trial", side_effect=mock_before_trial):
study = optuna.study.create_study(sampler=sampler)

def objective(trial: Trial) -> float:
assert trial._relative_params is None
def objective(trial: Trial) -> float:
assert trial._relative_params is None

trial.suggest_float("x", -10, 10)
trial.suggest_float("y", -10, 10)
assert trial._relative_params is not None
return -1
trial.suggest_float("x", -10, 10)
trial.suggest_float("y", -10, 10)
assert trial._relative_params is not None
return -1

study.optimize(objective, n_trials=10, n_jobs=n_jobs)
study.optimize(objective, n_trials=10, n_jobs=n_jobs)

0 comments on commit 396b703

Please sign in to comment.