Skip to content

Commit

Permalink
Update sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
btjanaka committed Nov 8, 2024
1 parent 85d7a31 commit c6b88b7
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 24 deletions.
25 changes: 19 additions & 6 deletions package/samplers/cmamae/example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import optuna
import optunahub
from optuna.study import StudyDirection

from sampler import CmaMaeSampler

Expand All @@ -21,17 +22,29 @@ def objective(trial: optuna.trial.Trial) -> float:
archive_ranges=[(-10, 10), (-10, 10)],
archive_learning_rate=0.1,
archive_threshold_min=-10,
n_emitters=15,
n_emitters=1,
emitter_x0={
"x": 5,
"y": 5
},
emitter_sigma0=0.1,
emitter_batch_size=36,
emitter_batch_size=5,
)
study = optuna.create_study(
sampler=sampler,
directions=[
# pyribs maximizes objectives.
StudyDirection.MAXIMIZE,
# The remaining values are measures, which do not have an
# optimization direction.
# TODO: Currently, using StudyDirection.NOT_SET is not allowed as
# Optuna assumes we either minimize or maximize.
StudyDirection.MINIMIZE,
StudyDirection.MINIMIZE,
],
)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=100)
print(study.best_trial.params)

fig = optuna.visualization.plot_optimization_history(study)
fig.write_image("cmamae_optimization_history.png")
# TODO: Visualization.
# fig = optuna.visualization.plot_optimization_history(study)
# fig.write_image("cmamae_optimization_history.png")
62 changes: 44 additions & 18 deletions package/samplers/cmamae/sampler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Iterable

import numpy as np
import optunahub
from optuna.distributions import BaseDistribution
from optuna.distributions import BaseDistribution, FloatDistribution
from optuna.study import Study
from optuna.trial import FrozenTrial, TrialState
from ribs.archives import GridArchive
Expand Down Expand Up @@ -68,11 +69,16 @@ def __init__(
emitter_sigma0: float,
emitter_batch_size: int,
) -> None:
super().__init__()

self._validate_params(param_names, emitter_x0)
self._param_names = param_names[:]

# NOTE: SimpleBaseSampler must know Optuna search_space information.
search_space = {
name: FloatDistribution(-1e9, 1e9) for name in self._param_names
}
super().__init__(search_space=search_space)

emitter_x0_np = self._convert_to_pyribs_params(emitter_x0)

archive = GridArchive(
Expand Down Expand Up @@ -108,6 +114,8 @@ def __init__(
result_archive=result_archive,
)

self._values_to_tell: list[list[float]] = []

def _validate_params(self, param_names: list[str],
emitter_x0: dict[str, float]) -> None:
dim = len(param_names)
Expand All @@ -122,6 +130,11 @@ def _validate_params(self, param_names: list[str],
"emitter_x0 does not contain the parameters listed in param_names. "
"Please provide an initial value for each parameter.")

def _validate_param_names(self, given_param_names: Iterable[str]) -> None:
if set(self._param_names) != set(given_param_names):
raise ValueError("The given param names must match the param names "
"initially passed to this sampler.")

def _convert_to_pyribs_params(self, params: dict[str, float]) -> np.ndarray:
np_params = np.empty(len(self._param_names), dtype=float)
for i, p in enumerate(self._param_names):
Expand All @@ -137,17 +150,16 @@ def _convert_to_optuna_params(self, params: np.ndarray) -> dict[str, float]:
def sample_relative(
self, study: Study, trial: FrozenTrial,
search_space: dict[str, BaseDistribution]) -> dict[str, float]:
self._validate_param_names(search_space.keys())

# Note: Batch optimization means we need to enqueue trials.
# https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.enqueue_trial
if trial.number % self._batch_size == 0:
sols = self._scheduler.ask()
for sol in sols:
params = self._convert_to_optuna_params(sol)
study.enqueue_trial(params)
solutions = self._scheduler.ask()
next_params = self._convert_to_optuna_params(solutions[0])
for solution in solutions[1:]:
params = self._convert_to_optuna_params(solution)
study.enqueue_trial(params)

# Probably, this trial is taken from the queue, so we do not have to take it?
# but I need to look into it.
return trial
return next_params

def after_trial(
self,
Expand All @@ -156,10 +168,24 @@ def after_trial(
state: TrialState,
values: Sequence[float] | None,
) -> None:
# TODO
if trial.number % self._batch_size == self._batch_size - 1:
results = [
t.values[trial.number - self._batch_size + 1:trial.number + 1]
for t in study.trials
]
scheduler.tell
# TODO: Is it safe to assume the parameters will always come back in the
# order that they were sent out by the scheduler? Pyribs makes that
# assumption and stores the solutions internally. If not, maybe we can
# retrieve solutions based on their trial ID?

self._validate_param_names(trial.params.keys())

# Store the trial result.
self._values_to_tell.append(values)

# If we have not retrieved the whole batch of solutions, then we should
# not tell() the results to the scheduler yet.
if len(self._values_to_tell) != self._batch_size:
return

# Tell the batch results to the external sampler (the scheduler) once the batch is ready.
values = np.asarray(self._values_to_tell)
self._scheduler.tell(objective=values[:, 0], measures=values[:, 1:])

# Empty the results.
self._values_to_tell = []

0 comments on commit c6b88b7

Please sign in to comment.