Skip to content

Commit

Permalink
pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
btjanaka committed Nov 8, 2024
1 parent c6b88b7 commit e02890d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
1 change: 1 addition & 0 deletions package/samplers/cmamae/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .sampler import CmaMaeSampler


__all__ = ["CmaMaeSampler"]
11 changes: 4 additions & 7 deletions package/samplers/cmamae/example.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import optuna
import optunahub
from optuna.study import StudyDirection

from sampler import CmaMaeSampler


# TODO: Replace above import with this.
# module = optunahub.load_module("samplers/pyribs")
# PyribsSampler = module.PyribsSampler


def objective(trial: optuna.trial.Trial) -> float:
def objective(trial: optuna.trial.Trial) -> tuple[float, float, float]:
"""Returns an objective followed by two measures."""
x = trial.suggest_float("x", -10, 10)
y = trial.suggest_float("y", -10, 10)
return -(x**2 + y**2) + 2, x, y
Expand All @@ -23,10 +23,7 @@ def objective(trial: optuna.trial.Trial) -> float:
archive_learning_rate=0.1,
archive_threshold_min=-10,
n_emitters=1,
emitter_x0={
"x": 5,
"y": 5
},
emitter_x0={"x": 5, "y": 5},
emitter_sigma0=0.1,
emitter_batch_size=5,
)
Expand Down
37 changes: 21 additions & 16 deletions package/samplers/cmamae/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from typing import Iterable

import numpy as np
import optunahub
from optuna.distributions import BaseDistribution, FloatDistribution
from optuna.distributions import BaseDistribution
from optuna.distributions import FloatDistribution
from optuna.study import Study
from optuna.trial import FrozenTrial, TrialState
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
import optunahub
from ribs.archives import GridArchive
from ribs.emitters import EvolutionStrategyEmitter
from ribs.schedulers import Scheduler


SimpleBaseSampler = optunahub.load_module("samplers/simple").SimpleBaseSampler


Expand Down Expand Up @@ -69,14 +72,11 @@ def __init__(
emitter_sigma0: float,
emitter_batch_size: int,
) -> None:

self._validate_params(param_names, emitter_x0)
self._param_names = param_names[:]

# NOTE: SimpleBaseSampler must know Optuna search_space information.
search_space = {
name: FloatDistribution(-1e9, 1e9) for name in self._param_names
}
search_space = {name: FloatDistribution(-1e9, 1e9) for name in self._param_names}
super().__init__(search_space=search_space)

emitter_x0_np = self._convert_to_pyribs_params(emitter_x0)
Expand All @@ -102,7 +102,8 @@ def __init__(
selection_rule="mu",
restart_rule="basic",
batch_size=emitter_batch_size,
) for _ in range(n_emitters)
)
for _ in range(n_emitters)
]

# Number of solutions generated in each batch from pyribs.
Expand All @@ -114,10 +115,9 @@ def __init__(
result_archive=result_archive,
)

self._values_to_tell: list[list[float]] = []
self._values_to_tell: list[Sequence[float]] = []

def _validate_params(self, param_names: list[str],
emitter_x0: dict[str, float]) -> None:
def _validate_params(self, param_names: list[str], emitter_x0: dict[str, float]) -> None:
dim = len(param_names)
param_set = set(param_names)
if dim != len(param_set):
Expand All @@ -128,12 +128,15 @@ def _validate_params(self, param_names: list[str],
if set(param_names) != emitter_x0.keys():
raise ValueError(
"emitter_x0 does not contain the parameters listed in param_names. "
"Please provide an initial value for each parameter.")
"Please provide an initial value for each parameter."
)

def _validate_param_names(self, given_param_names: Iterable[str]) -> None:
if set(self._param_names) != set(given_param_names):
raise ValueError("The given param names must match the param names "
"initially passed to this sampler.")
raise ValueError(
"The given param names must match the param names "
"initially passed to this sampler."
)

def _convert_to_pyribs_params(self, params: dict[str, float]) -> np.ndarray:
np_params = np.empty(len(self._param_names), dtype=float)
Expand All @@ -148,8 +151,8 @@ def _convert_to_optuna_params(self, params: np.ndarray) -> dict[str, float]:
return dict_params

def sample_relative(
self, study: Study, trial: FrozenTrial,
search_space: dict[str, BaseDistribution]) -> dict[str, float]:
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> dict[str, float]:
self._validate_param_names(search_space.keys())

# Note: Batch optimization means we need to enqueue trials.
Expand Down Expand Up @@ -185,6 +188,8 @@ def after_trial(

# Tell the batch results to external sampler once the batch is ready.
values = np.asarray(self._values_to_tell)
# TODO: This assumes the objective is the first value while measures are
# the remaining values; we should document this somewhere.
self._scheduler.tell(objective=values[:, 0], measures=values[:, 1:])

# Empty the results.
Expand Down

0 comments on commit e02890d

Please sign in to comment.