From 39080b43b4e2ad5ba221cae322337b1b15a1ba1e Mon Sep 17 00:00:00 2001 From: y0z Date: Fri, 21 Jun 2024 19:32:36 +0900 Subject: [PATCH] Fix hebo --- package/samplers/hebo/requirements.txt | 2 ++ package/samplers/hebo/sampler.py | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/package/samplers/hebo/requirements.txt b/package/samplers/hebo/requirements.txt index a9121a3e..67584885 100644 --- a/package/samplers/hebo/requirements.txt +++ b/package/samplers/hebo/requirements.txt @@ -1,3 +1,5 @@ +numpy optuna optunahub +pandas hebo@git+https://github.com/huawei-noah/HEBO.git#subdirectory=HEBO diff --git a/package/samplers/hebo/sampler.py b/package/samplers/hebo/sampler.py index 35bfcdde..e206beaa 100644 --- a/package/samplers/hebo/sampler.py +++ b/package/samplers/hebo/sampler.py @@ -1,13 +1,19 @@ from __future__ import annotations +from typing import Optional +from typing import Sequence + from optuna.distributions import BaseDistribution from optuna.distributions import CategoricalDistribution from optuna.distributions import FloatDistribution from optuna.distributions import IntDistribution from optuna.study import Study -from optuna.trial import FrozenTrial +from optuna.trial import FrozenTrial, TrialState import optunahub +import numpy as np +import pandas as pd + from hebo.design_space.design_space import DesignSpace from hebo.optimizers.hebo import HEBO @@ -30,6 +36,15 @@ def sample_relative( params[name] = params_pd[name].to_numpy()[0] return params + def after_trial( + self, + study: Study, + trial: FrozenTrial, + state: TrialState, + values: Optional[Sequence[float]], + ) -> None: + self._hebo.observe(pd.DataFrame([trial.params]), np.asarray([values])) + def _convert_to_hebo_design_space( self, search_space: dict[str, BaseDistribution] ) -> DesignSpace: