diff --git a/package/samplers/cmamae/sampler.py b/package/samplers/cmamae/sampler.py index 3557f7de..15e0e4f2 100644 --- a/package/samplers/cmamae/sampler.py +++ b/package/samplers/cmamae/sampler.py @@ -113,7 +113,8 @@ def __init__( # Number of solutions generated in each batch from pyribs. self._batch_size = n_emitters * emitter_batch_size - self._scheduler = Scheduler( + # Public to allow access for, e.g., visualization. + self.scheduler = Scheduler( archive, emitters, result_archive=result_archive, @@ -161,7 +162,7 @@ def sample_relative( self._validate_param_names(search_space.keys()) # Note: Batch optimization means we need to enqueue trials. - solutions = self._scheduler.ask() + solutions = self.scheduler.ask() next_params = self._convert_to_optuna_params(solutions[0]) for solution in solutions[1:]: params = self._convert_to_optuna_params(solution) @@ -196,7 +197,7 @@ def after_trial( # Tell the batch results to external sampler once the batch is ready. values_to_tell = np.asarray(self._values_to_tell)[np.argsort(self._stored_trial_numbers)] - self._scheduler.tell(objective=values_to_tell[:, 0], measures=values_to_tell[:, 1:]) + self.scheduler.tell(objective=values_to_tell[:, 0], measures=values_to_tell[:, 1:]) # Empty the results. self._values_to_tell = []