Skip to content

Commit

Permalink
Call internal sampler's before_trial
Browse files Browse the repository at this point in the history
  • Loading branch information
not522 authored and gen740 committed Sep 22, 2023
1 parent 76381ba commit a920a91
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 0 deletions.
3 changes: 3 additions & 0 deletions optuna/integration/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,9 @@ def reseed_rng(self) -> None:
if self._seed is not None:
self._seed = numpy.random.RandomState().randint(numpy.iinfo(numpy.int32).max)

def before_trial(self, study: Study, trial: FrozenTrial) -> None:
self._independent_sampler.before_trial(study, trial)

def after_trial(
self,
study: Study,
Expand Down
3 changes: 3 additions & 0 deletions optuna/integration/cma.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ def _log_independent_sampling(self, trial: FrozenTrial, param_name: str) -> None
)
)

def before_trial(self, study: Study, trial: FrozenTrial) -> None:
self._independent_sampler.before_trial(study, trial)

def after_trial(
self,
study: Study,
Expand Down
5 changes: 5 additions & 0 deletions optuna/samplers/_cmaes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import math
import pickle
Expand Down Expand Up @@ -779,6 +781,9 @@ def _get_solution_trials(
generation_attr_key = self._attr_keys.generation(n_restarts)
return [t for t in trials if generation == t.system_attrs.get(generation_attr_key, -1)]

def before_trial(self, study: optuna.Study, trial: FrozenTrial) -> None:
self._independent_sampler.before_trial(study, trial)

def after_trial(
self,
study: "optuna.Study",
Expand Down
3 changes: 3 additions & 0 deletions optuna/samplers/_nsgaiii.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ def _select_elite_population(
break
return elite_population

def before_trial(self, study: Study, trial: FrozenTrial) -> None:
self._random_sampler.before_trial(study, trial)

def after_trial(
self,
study: Study,
Expand Down
3 changes: 3 additions & 0 deletions optuna/samplers/_partial_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def sample_independent(
)
return param_value

def before_trial(self, study: Study, trial: FrozenTrial) -> None:
self._base_sampler.before_trial(study, trial)

def after_trial(
self,
study: Study,
Expand Down
3 changes: 3 additions & 0 deletions optuna/samplers/_qmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def sample_relative(
sample = trans.bounds[:, 0] + sample * (trans.bounds[:, 1] - trans.bounds[:, 0])
return trans.untransform(sample[0, :])

def before_trial(self, study: Study, trial: FrozenTrial) -> None:
self._independent_sampler.before_trial(study, trial)

def after_trial(
self,
study: "optuna.Study",
Expand Down
3 changes: 3 additions & 0 deletions optuna/samplers/_tpe/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ def objective(trial):
"weights": default_weights,
}

def before_trial(self, study: Study, trial: FrozenTrial) -> None:
self._random_sampler.before_trial(study, trial)

def after_trial(
self,
study: Study,
Expand Down
3 changes: 3 additions & 0 deletions optuna/samplers/nsgaii/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ def _collect_parent_population(self, study: Study) -> tuple[int, list[FrozenTria

return parent_generation, parent_population

def before_trial(self, study: Study, trial: FrozenTrial) -> None:
self._random_sampler.before_trial(study, trial)

def after_trial(
self,
study: Study,
Expand Down

0 comments on commit a920a91

Please sign in to comment.