Skip to content

Commit

Permalink
Merge pull request #176 from nabenabe0928/hotfix-smac
Browse files Browse the repository at this point in the history
Hotfix SMAC and use `system_attrs`
  • Loading branch information
y0z authored Nov 8, 2024
2 parents b883957 + 86c1819 commit 983540a
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions package/samplers/smac_sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
from smac.runhistory.dataclasses import TrialValue
from smac.runhistory.enumerations import StatusType
from smac.scenario import Scenario
from smac.utils.configspace import get_config_hash


SimpleBaseSampler = optunahub.load_module("samplers/simple").SimpleBaseSampler
_SMAC_INSTANCE_KEY = "smac:instance"
_SMAC_SEED_KEY = "smac:seed"


class SMACSampler(SimpleBaseSampler): # type: ignore
Expand Down Expand Up @@ -149,9 +150,6 @@ def _dummmy_target_func(config: Configuration, seed: int = 0) -> float:
)
self.smac = smac

# Used to store the instance-seed pairs of each evaluated configurations.
self._runs_instance_seed_keys: dict[str, tuple[str | None, int]] = {}

def _get_surrogate_model(
self,
scenario: Scenario,
Expand Down Expand Up @@ -218,10 +216,10 @@ def sample_relative(
) -> dict[str, float]:
trial_info: TrialInfo = self.smac.ask()
cfg = trial_info.config
self._runs_instance_seed_keys[get_config_hash(cfg)] = (
trial_info.instance,
trial_info.seed,
study._storage.set_trial_system_attr(
trial._trial_id, _SMAC_INSTANCE_KEY, trial_info.instance
)
study._storage.set_trial_system_attr(trial._trial_id, _SMAC_SEED_KEY, trial_info.seed)
params = {}
for name, hp_value in cfg.items():
if name in self._hp_scale_value:
Expand All @@ -244,7 +242,9 @@ def after_trial(
cfg_params = {}
for name, hp_value in params.items():
if name in self._hp_scale_value:
hp_value = self._step_hp_to_intger(hp_value, scale_info=self._hp_scale_value[name])
hp_value = self._step_hp_to_integer(
hp_value, scale_info=self._hp_scale_value[name]
)
cfg_params[name] = hp_value

# params to smac HP, in SMAC, we always perform the minimization.
Expand All @@ -262,8 +262,8 @@ def after_trial(
trial_value = TrialValue(y, status=status)

cfg = Configuration(configuration_space=self._cs, values=cfg_params)
# Since Optuna does not provide us the
instance, seed = self._runs_instance_seed_keys[get_config_hash(cfg)]
instance = study._storage.get_trial_system_attrs(trial._trial_id).get(_SMAC_INSTANCE_KEY)
seed = study._storage.get_trial_system_attrs(trial._trial_id).get(_SMAC_SEED_KEY)
info = TrialInfo(cfg, seed=seed, instance=instance)
self.smac.tell(info=info, value=trial_value, save=False)

Expand Down

0 comments on commit 983540a

Please sign in to comment.