Skip to content

Commit

Permalink
Add the trick to handle dynamic search space
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Dec 13, 2024
1 parent 738d452 commit f5285e0
Showing 1 changed file with 4 additions and 23 deletions.
27 changes: 4 additions & 23 deletions package/samplers/hebo/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,6 @@
_logger = get_logger(f"optuna.{__name__}")


def _validate_distributions(
param_name: str,
target_distribution: FloatDistribution | IntDistribution,
trials: list[FrozenTrial],
) -> None:
assert not target_distribution.log and target_distribution.step is not None
dists = [t.distributions[param_name] for t in trials]
lows = np.asarray([d.low for d in dists])
highs = np.asarray([d.high for d in dists])
steps = np.asarray([d.step if d.step is not None else np.nan for d in dists])
if (
not np.allclose(target_distribution.low, lows)
or not np.allclose(target_distribution.high, highs)
or not np.allclose(target_distribution.step, steps)
):
raise ValueError(
"When using the `step` option or `suggest_int`, `low`, `high`, and `step` cannot be "
f"modified during study, but modifications were detected in `{param_name}`."
)


class HEBOSampler(optunahub.samplers.SimpleBaseSampler):
"""A sampler using `HEBO <https://github.com/huawei-noah/HEBO/tree/master/HEBO>__` as the backend.
Expand Down Expand Up @@ -155,8 +134,10 @@ def _transform_to_dict_and_observe(
and not dist.log
and dist.step is not None
):
_validate_distributions(name, dist, trials)
params[name] = np.round((params[name] - dist.low) / dist.step).astype(int)
# NOTE(nabenabe): We do not round here because HEBO treats params as float even if
# the domain is defined on integer. By not rounding, HEBO can handle any changes in
# the domain of these parameters such as changes in low, high, and step.
params[name] = (params[name] - dist.low) / dist.step

hebo.observe(params, nan_padded_values)

Expand Down

0 comments on commit f5285e0

Please sign in to comment.