From 86c1819aa8661d98bccfb757d3d7884e965257f4 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Fri, 8 Nov 2024 08:46:34 +0100 Subject: [PATCH] Hotfix smac and use system_attrs --- package/samplers/smac_sampler/sampler.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/package/samplers/smac_sampler/sampler.py b/package/samplers/smac_sampler/sampler.py index ababce2f..10c43204 100644 --- a/package/samplers/smac_sampler/sampler.py +++ b/package/samplers/smac_sampler/sampler.py @@ -32,10 +32,11 @@ from smac.runhistory.dataclasses import TrialValue from smac.runhistory.enumerations import StatusType from smac.scenario import Scenario -from smac.utils.configspace import get_config_hash SimpleBaseSampler = optunahub.load_module("samplers/simple").SimpleBaseSampler +_SMAC_INSTANCE_KEY = "smac:instance" +_SMAC_SEED_KEY = "smac:seed" class SMACSampler(SimpleBaseSampler): # type: ignore @@ -149,9 +150,6 @@ def _dummmy_target_func(config: Configuration, seed: int = 0) -> float: ) self.smac = smac - # Used to store the instance-seed pairs of each evaluated configurations. - self._runs_instance_seed_keys: dict[str, tuple[str | None, int]] = {} - def _get_surrogate_model( self, scenario: Scenario, @@ -218,10 +216,10 @@ def sample_relative( ) -> dict[str, float]: trial_info: TrialInfo = self.smac.ask() cfg = trial_info.config - self._runs_instance_seed_keys[get_config_hash(cfg)] = ( - trial_info.instance, - trial_info.seed, + study._storage.set_trial_system_attr( + trial._trial_id, _SMAC_INSTANCE_KEY, trial_info.instance ) + study._storage.set_trial_system_attr(trial._trial_id, _SMAC_SEED_KEY, trial_info.seed) params = {} for name, hp_value in cfg.items(): if name in self._hp_scale_value: @@ -244,7 +242,9 @@ def after_trial( cfg_params = {} for name, hp_value in params.items(): if name in self._hp_scale_value: - hp_value = self._step_hp_to_intger(hp_value, scale_info=self._hp_scale_value[name]) + hp_value = self._step_hp_to_integer( + hp_value, scale_info=self._hp_scale_value[name] + ) cfg_params[name] = hp_value # params to smac HP, in SMAC, we always perform the minimization. @@ -262,8 +262,8 @@ def after_trial( trial_value = TrialValue(y, status=status) cfg = Configuration(configuration_space=self._cs, values=cfg_params) - # Since Optuna does not provide us the - instance, seed = self._runs_instance_seed_keys[get_config_hash(cfg)] + instance = study._storage.get_trial_system_attrs(trial._trial_id).get(_SMAC_INSTANCE_KEY) + seed = study._storage.get_trial_system_attrs(trial._trial_id).get(_SMAC_SEED_KEY) info = TrialInfo(cfg, seed=seed, instance=instance) self.smac.tell(info=info, value=trial_value, save=False)