Skip to content

Commit

Permalink
Apply not's comment
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 29, 2024
1 parent a2a1896 commit d56cbe3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
7 changes: 5 additions & 2 deletions package/samplers/auto_sampler/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

_MAXINT32 = (1 << 31) - 1
_SAMPLER_KEY = "auto:sampler"
# NOTE(nabenabe): The prefix `optuna.` enables us to use the optuna logger externally.
_logger = get_logger(f"optuna.{__name__}")


Expand All @@ -52,7 +53,8 @@ class AutoSampler(BaseSampler):
def objective(trial):
x = trial.suggest_float("x", -5, 5)
return x**2
y = trial.suggest_float("y", -5, 5)
return x**2 + y**2
module = optunahub.load_module("samplers/auto_sampler")
study = optuna.create_study(sampler=module.AutoSampler())
Expand Down Expand Up @@ -120,6 +122,7 @@ def _include_conditional_param(self, study: Study) -> bool:
def _determine_multi_objective_sampler(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> BaseSampler:
# TODO(nabenabe): Add more efficient heuristic for MO.
if isinstance(self._sampler, NSGAIISampler):
return self._sampler

Expand Down Expand Up @@ -152,7 +155,6 @@ def _determine_single_objective_sampler(
complete_trials = study._get_trials(
deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True
)
complete_trials.sort(key=lambda trial: trial.datetime_complete)
if len(complete_trials) < self._N_COMPLETE_TRIALS_FOR_CMAES:
# Use ``GPSampler`` if search space is numerical and
# len(complete_trials) < _N_COMPLETE_TRIALS_FOR_CMAES.
Expand All @@ -162,6 +164,7 @@ def _determine_single_objective_sampler(
# Use ``CmaEsSampler`` if search space is numerical and
# len(complete_trials) > _N_COMPLETE_TRIALS_FOR_CMAES.
# Warm start CMA-ES with the first _N_COMPLETE_TRIALS_FOR_CMAES complete trials.
complete_trials.sort(key=lambda trial: trial.datetime_complete)
warm_start_trials = complete_trials[: self._N_COMPLETE_TRIALS_FOR_CMAES]
# NOTE(nabenabe): ``CmaEsSampler`` internally falls back to ``RandomSampler`` for
# 1D problems.
Expand Down
1 change: 0 additions & 1 deletion package/samplers/auto_sampler/tests/test_auto_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


# TODO(nabaenabe): Add the CI for this sampler.
# optuna.logging.set_verbosity(optuna.logging.CRITICAL)

AutoSampler = optunahub.load_local_module(
package="samplers/auto_sampler", registry_root="package/"
Expand Down

0 comments on commit d56cbe3

Please sign in to comment.