From fd91c4c683fc32796e3901eeed2962b56c470c0f Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Tue, 1 Oct 2024 04:26:49 +0200 Subject: [PATCH] Fix the example --- package/samplers/user_prior_cmaes/README.md | 2 +- package/samplers/user_prior_cmaes/sampler.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/package/samplers/user_prior_cmaes/README.md b/package/samplers/user_prior_cmaes/README.md index d22b4790..01e03663 100644 --- a/package/samplers/user_prior_cmaes/README.md +++ b/package/samplers/user_prior_cmaes/README.md @@ -47,7 +47,7 @@ def objective(trial: optuna.Trial) -> float: if __name__ == "__main__": module = optunahub.load_module(package="samplers/user_prior_cmaes") - sampler = module.UserPriorCmaEsSampler(param_names=["x", "y"], mu0=np.array([3., -48.]), cov0=np.diag([0.2, 2.0])) + sampler = module.UserPriorCmaEsSampler(param_names=["x", "y"], mu0=np.array([-48., 3.]), cov0=np.diag([2., 0.2])) study = optuna.create_study(sampler=sampler) study.optimize(objective, n_trials=20) print(study.best_trial.value, study.best_trial.params) diff --git a/package/samplers/user_prior_cmaes/sampler.py b/package/samplers/user_prior_cmaes/sampler.py index 183facdc..5f6ce774 100644 --- a/package/samplers/user_prior_cmaes/sampler.py +++ b/package/samplers/user_prior_cmaes/sampler.py @@ -108,6 +108,7 @@ def sample_relative( "The most probable reason is duplicated names in param_names." ) elif len(search_space) != 0: + # Ensure the parameter order is identical to that in param_names. search_space = { param_name: search_space[param_name] for param_name in self._param_names }