Skip to content

Commit

Permalink
Add validation of dists
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Dec 13, 2024
1 parent df00903 commit 738d452
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion package/samplers/hebo/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@
_logger = get_logger(f"optuna.{__name__}")


def _validate_distributions(
param_name: str,
target_distribution: FloatDistribution | IntDistribution,
trials: list[FrozenTrial],
) -> None:
assert not target_distribution.log and target_distribution.step is not None
dists = [t.distributions[param_name] for t in trials]
lows = np.asarray([d.low for d in dists])
highs = np.asarray([d.high for d in dists])
steps = np.asarray([d.step if d.step is not None else np.nan for d in dists])
if (
not np.allclose(target_distribution.low, lows)
or not np.allclose(target_distribution.high, highs)
or not np.allclose(target_distribution.step, steps)
):
raise ValueError(
"When using the `step` option or `suggest_int`, `low`, `high`, and `step` cannot be "
f"modified during study, but modifications were detected in `{param_name}`."
)


class HEBOSampler(optunahub.samplers.SimpleBaseSampler):
"""A sampler using `HEBO <https://github.com/huawei-noah/HEBO/tree/master/HEBO>__` as the backend.
Expand Down Expand Up @@ -134,6 +155,7 @@ def _transform_to_dict_and_observe(
and not dist.log
and dist.step is not None
):
_validate_distributions(name, dist, trials)
params[name] = np.round((params[name] - dist.low) / dist.step).astype(int)

hebo.observe(params, nan_padded_values)
Expand Down Expand Up @@ -161,7 +183,6 @@ def _sample_relative_stateless(
return {}

trials = [t for t in trials if set(search_space.keys()) <= set(t.params.keys())]

seed = int(self._rng.integers(low=1, high=(1 << 31)))
hebo = HEBO(self._convert_to_hebo_design_space(search_space), scramble_seed=seed)
self._transform_to_dict_and_observe(hebo, search_space, study, trials)
Expand Down

0 comments on commit 738d452

Please sign in to comment.