From a920a91e4e5a32ade707f3e9819fff0f96cc7946 Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Tue, 12 Sep 2023 14:47:03 +0900 Subject: [PATCH] Call internal sampler's before_trial --- optuna/integration/botorch.py | 3 +++ optuna/integration/cma.py | 3 +++ optuna/samplers/_cmaes.py | 5 +++++ optuna/samplers/_nsgaiii.py | 3 +++ optuna/samplers/_partial_fixed.py | 3 +++ optuna/samplers/_qmc.py | 3 +++ optuna/samplers/_tpe/sampler.py | 3 +++ optuna/samplers/nsgaii/_sampler.py | 3 +++ 8 files changed, 26 insertions(+) diff --git a/optuna/integration/botorch.py b/optuna/integration/botorch.py index 54ee6c8fd6..2e49984b83 100644 --- a/optuna/integration/botorch.py +++ b/optuna/integration/botorch.py @@ -883,6 +883,9 @@ def reseed_rng(self) -> None: if self._seed is not None: self._seed = numpy.random.RandomState().randint(numpy.iinfo(numpy.int32).max) + def before_trial(self, study: Study, trial: FrozenTrial) -> None: + self._independent_sampler.before_trial(study, trial) + def after_trial( self, study: Study, diff --git a/optuna/integration/cma.py b/optuna/integration/cma.py index 0bce5db94c..24506b600d 100644 --- a/optuna/integration/cma.py +++ b/optuna/integration/cma.py @@ -291,6 +291,9 @@ def _log_independent_sampling(self, trial: FrozenTrial, param_name: str) -> None ) ) + def before_trial(self, study: Study, trial: FrozenTrial) -> None: + self._independent_sampler.before_trial(study, trial) + def after_trial( self, study: Study, diff --git a/optuna/samplers/_cmaes.py b/optuna/samplers/_cmaes.py index 920dcc48c6..9233f1219d 100644 --- a/optuna/samplers/_cmaes.py +++ b/optuna/samplers/_cmaes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import math import pickle @@ -779,6 +781,9 @@ def _get_solution_trials( generation_attr_key = self._attr_keys.generation(n_restarts) return [t for t in trials if generation == t.system_attrs.get(generation_attr_key, -1)] + def before_trial(self, study: optuna.Study, trial: FrozenTrial) -> None: + self._independent_sampler.before_trial(study, trial) + def after_trial( self, study: "optuna.Study", diff --git a/optuna/samplers/_nsgaiii.py b/optuna/samplers/_nsgaiii.py index 93b1a0b08c..2ba357745d 100644 --- a/optuna/samplers/_nsgaiii.py +++ b/optuna/samplers/_nsgaiii.py @@ -315,6 +315,9 @@ def _select_elite_population( break return elite_population + def before_trial(self, study: Study, trial: FrozenTrial) -> None: + self._random_sampler.before_trial(study, trial) + def after_trial( self, study: Study, diff --git a/optuna/samplers/_partial_fixed.py b/optuna/samplers/_partial_fixed.py index 050f2593d3..d8eb62f914 100644 --- a/optuna/samplers/_partial_fixed.py +++ b/optuna/samplers/_partial_fixed.py @@ -106,6 +106,9 @@ def sample_independent( ) return param_value + def before_trial(self, study: Study, trial: FrozenTrial) -> None: + self._base_sampler.before_trial(study, trial) + def after_trial( self, study: Study, diff --git a/optuna/samplers/_qmc.py b/optuna/samplers/_qmc.py index bce3d7fa3c..b9d15f6238 100644 --- a/optuna/samplers/_qmc.py +++ b/optuna/samplers/_qmc.py @@ -249,6 +249,9 @@ def sample_relative( sample = trans.bounds[:, 0] + sample * (trans.bounds[:, 1] - trans.bounds[:, 0]) return trans.untransform(sample[0, :]) + def before_trial(self, study: Study, trial: FrozenTrial) -> None: + self._independent_sampler.before_trial(study, trial) + def after_trial( self, study: "optuna.Study", diff --git a/optuna/samplers/_tpe/sampler.py b/optuna/samplers/_tpe/sampler.py index f0317d3c45..a725176c8d 100644 --- a/optuna/samplers/_tpe/sampler.py +++ b/optuna/samplers/_tpe/sampler.py @@ -553,6 +553,9 @@ def objective(trial): "weights": default_weights, } + def before_trial(self, study: Study, trial: FrozenTrial) -> None: + self._random_sampler.before_trial(study, trial) + def after_trial( self, study: Study, diff --git a/optuna/samplers/nsgaii/_sampler.py b/optuna/samplers/nsgaii/_sampler.py index bf03bd4eed..6c081a87c9 100644 --- a/optuna/samplers/nsgaii/_sampler.py +++ b/optuna/samplers/nsgaii/_sampler.py @@ -355,6 +355,9 @@ def _collect_parent_population(self, study: Study) -> tuple[int, list[FrozenTria return parent_generation, parent_population + def before_trial(self, study: Study, trial: FrozenTrial) -> None: + self._random_sampler.before_trial(study, trial) + def after_trial( self, study: Study,