diff --git a/package/samplers/hebo/requirements.txt b/package/samplers/hebo/requirements.txt index a9121a3e..67584885 100644 --- a/package/samplers/hebo/requirements.txt +++ b/package/samplers/hebo/requirements.txt @@ -1,3 +1,5 @@ +numpy optuna optunahub +pandas hebo@git+https://github.com/huawei-noah/HEBO.git#subdirectory=HEBO diff --git a/package/samplers/hebo/sampler.py b/package/samplers/hebo/sampler.py index 35bfcdde..e206beaa 100644 --- a/package/samplers/hebo/sampler.py +++ b/package/samplers/hebo/sampler.py @@ -1,13 +1,19 @@ from __future__ import annotations +from typing import Optional +from typing import Sequence + from optuna.distributions import BaseDistribution from optuna.distributions import CategoricalDistribution from optuna.distributions import FloatDistribution from optuna.distributions import IntDistribution from optuna.study import Study -from optuna.trial import FrozenTrial +from optuna.trial import FrozenTrial, TrialState import optunahub +import numpy as np +import pandas as pd + from hebo.design_space.design_space import DesignSpace from hebo.optimizers.hebo import HEBO @@ -30,6 +36,15 @@ def sample_relative( params[name] = params_pd[name].to_numpy()[0] return params + def after_trial( + self, + study: Study, + trial: FrozenTrial, + state: TrialState, + values: Optional[Sequence[float]], + ) -> None: + self._hebo.observe(pd.DataFrame([trial.params]), np.asarray([values])) + def _convert_to_hebo_design_space( self, search_space: dict[str, BaseDistribution] ) -> DesignSpace: