From ea89dd19d12d7b85b848717823b606d9895cb6ed Mon Sep 17 00:00:00 2001
From: nabenabe0928 <shuhei.watanabe.utokyo@gmail.com>
Date: Thu, 31 Oct 2024 08:44:44 +0100
Subject: [PATCH] Apply the feedback from the mob review

---
 package/samplers/auto_sampler/_sampler.py     | 27 ++++++++++++-------
 .../auto_sampler/tests/test_auto_sampler.py   | 20 ++++++++++++++
 2 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/package/samplers/auto_sampler/_sampler.py b/package/samplers/auto_sampler/_sampler.py
index 28e52e4c..55669274 100644
--- a/package/samplers/auto_sampler/_sampler.py
+++ b/package/samplers/auto_sampler/_sampler.py
@@ -159,16 +159,23 @@ def _determine_single_objective_sampler(
             # len(complete_trials) < _N_COMPLETE_TRIALS_FOR_CMAES.
             if not isinstance(self._sampler, GPSampler):
                 return GPSampler(seed=seed)
-        elif not isinstance(self._sampler, CmaEsSampler):
-            # Use ``CmaEsSampler`` if search space is numerical and
-            # len(complete_trials) > _N_COMPLETE_TRIALS_FOR_CMAES.
-            # Warm start CMA-ES with the first _N_COMPLETE_TRIALS_FOR_CMAES complete trials.
-            complete_trials.sort(key=lambda trial: trial.datetime_complete)
-            warm_start_trials = complete_trials[: self._N_COMPLETE_TRIALS_FOR_CMAES]
-            # NOTE(nabenabe): ``CmaEsSampler`` internally falls back to ``RandomSampler`` for
-            # 1D problems.
-            return CmaEsSampler(
-                seed=seed, source_trials=warm_start_trials, warn_independent_sampling=True
+        elif len(search_space) > 1:
+            if not isinstance(self._sampler, CmaEsSampler):
+                # Use ``CmaEsSampler`` if search space is numerical and
+                # len(complete_trials) > _N_COMPLETE_TRIALS_FOR_CMAES.
+                # Warm start CMA-ES with the first _N_COMPLETE_TRIALS_FOR_CMAES complete trials.
+                complete_trials.sort(key=lambda trial: trial.datetime_complete)
+                warm_start_trials = complete_trials[: self._N_COMPLETE_TRIALS_FOR_CMAES]
+                return CmaEsSampler(
+                    seed=seed, source_trials=warm_start_trials, warn_independent_sampling=True
+                )
+        else:
+            return TPESampler(
+                seed=seed,
+                multivariate=True,
+                warn_independent_sampling=False,
+                constraints_func=self._constraints_func,
+                constant_liar=True,
             )
 
         return self._sampler  # No update happens to self._sampler.
diff --git a/package/samplers/auto_sampler/tests/test_auto_sampler.py b/package/samplers/auto_sampler/tests/test_auto_sampler.py
index b5432636..130473cc 100644
--- a/package/samplers/auto_sampler/tests/test_auto_sampler.py
+++ b/package/samplers/auto_sampler/tests/test_auto_sampler.py
@@ -14,6 +14,11 @@
 parametrize_constraints = pytest.mark.parametrize("use_constraint", [True, False])
 
 
+def objective_1d(trial: optuna.Trial) -> float:
+    x = trial.suggest_float("x", -5, 5)
+    return x**2
+
+
 def objective(trial: optuna.Trial) -> float:
     x = trial.suggest_float("x", -5, 5)
     y = trial.suggest_int("y", -5, 5)
@@ -105,6 +110,21 @@ def test_choose_cmaes() -> None:
     ] * n_trials_of_cmaes == sampler_names
 
 
+def test_choose_tpe_for_1d() -> None:
+    # This test must be performed with a numerical objective function.
+    # For 1d problems, TPESampler will be chosen instead of CmaEsSampler.
+    n_trials_of_tpe = 100
+    n_trials_before_tpe = 20
+    auto_sampler = AutoSampler()
+    auto_sampler._N_COMPLETE_TRIALS_FOR_CMAES = n_trials_before_tpe
+    study = optuna.create_study(sampler=auto_sampler)
+    study.optimize(objective_1d, n_trials=n_trials_of_tpe + n_trials_before_tpe)
+    sampler_names = _get_used_sampler_names(study)
+    assert ["RandomSampler"] + ["GPSampler"] * (n_trials_before_tpe - 1) + [
+        "TPESampler"
+    ] * n_trials_of_tpe == sampler_names
+
+
 def test_choose_tpe_in_single_with_constraints() -> None:
     n_trials = 30
     auto_sampler = AutoSampler(constraints_func=constraints_func)