From 738d45281b498657bc463be586edd232e3ff7719 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Fri, 13 Dec 2024 04:19:23 +0100 Subject: [PATCH] Add validation of dists --- package/samplers/hebo/sampler.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/package/samplers/hebo/sampler.py b/package/samplers/hebo/sampler.py index d1a5a1f..18279bd 100644 --- a/package/samplers/hebo/sampler.py +++ b/package/samplers/hebo/sampler.py @@ -26,6 +26,27 @@ _logger = get_logger(f"optuna.{__name__}") +def _validate_distributions( + param_name: str, + target_distribution: FloatDistribution | IntDistribution, + trials: list[FrozenTrial], +) -> None: + assert not target_distribution.log and target_distribution.step is not None + dists = [t.distributions[param_name] for t in trials] + lows = np.asarray([d.low for d in dists]) + highs = np.asarray([d.high for d in dists]) + steps = np.asarray([d.step if d.step is not None else np.nan for d in dists]) + if ( + not np.allclose(target_distribution.low, lows) + or not np.allclose(target_distribution.high, highs) + or not np.allclose(target_distribution.step, steps) + ): + raise ValueError( + "When using the `step` option or `suggest_int`, `low`, `high`, and `step` cannot be " + f"modified during study, but modifications were detected in `{param_name}`." + ) + + class HEBOSampler(optunahub.samplers.SimpleBaseSampler): """A sampler using `HEBO __` as the backend. @@ -134,6 +155,7 @@ def _transform_to_dict_and_observe( and not dist.log and dist.step is not None ): + _validate_distributions(name, dist, trials) params[name] = np.round((params[name] - dist.low) / dist.step).astype(int) hebo.observe(params, nan_padded_values) @@ -161,7 +183,6 @@ def _sample_relative_stateless( return {} trials = [t for t in trials if set(search_space.keys()) <= set(t.params.keys())] - seed = int(self._rng.integers(low=1, high=(1 << 31))) hebo = HEBO(self._convert_to_hebo_design_space(search_space), scramble_seed=seed) self._transform_to_dict_and_observe(hebo, search_space, study, trials)