Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Nov 27, 2024
1 parent 1d303b3 commit 8462391
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions package/samplers/cmamae/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def __init__(
self._validate_params(param_names, emitter_x0)
self._param_names = param_names.copy()
self._measure_names = measure_names.copy()
if len(set(self._measure_names)) != 2:
raise ValueError(
"measure_names must be a list of two unique measure names, "
f"but got measure_names={measure_names}."
)

# NOTE: SimpleBaseSampler must know Optuna search_space information.
search_space = {name: FloatDistribution(-1e9, 1e9) for name in self._param_names}
Expand Down Expand Up @@ -193,9 +198,9 @@ def after_trial(
user_attrs = trial.user_attrs
if any(measure_name not in user_attrs for measure_name in self._measure_names):
raise KeyError(
f"All of measure in measure_names={self._measure_names} must be set to "
"trial.user_attrs. Please call trial.set_user_attr(<measure_name>, <value>) "
"for each measure."
f"All of measures in measure_names={self._measure_names} must be set to "
"trial.user_attrs. Please call `trial.set_user_attr(<measure_name>, <value>)` "
"for each measure in your objective function."
)

self._raise_error_if_multi_objective(study)
Expand Down

0 comments on commit 8462391

Please sign in to comment.