diff --git a/recipes/001_first.py b/recipes/001_first.py index 4c25ea47..91fbaa4b 100644 --- a/recipes/001_first.py +++ b/recipes/001_first.py @@ -43,15 +43,13 @@ # `force_reload=True` argument forces downloading the sampler from the registry. # If we set `force_reload` to `False`, we use the cached data in our local storage if available. -SimpleBaseSampler = optunahub.load_module("samplers/simple").SimpleBaseSampler +SimpleSampler = optunahub.load_module("samplers/simple").SimpleSampler -class MySampler(SimpleBaseSampler): # type: ignore +class MySampler(SimpleSampler): # type: ignore # By default, search space will be estimated automatically like Optuna's built-in samplers. - # You can fix the search spacd by `search_space` argument of `SimpleBaseSampler` class. - def __init__( - self, search_space: dict[str, optuna.distributions.BaseDistribution] | None = None - ) -> None: + # You can fix the search spacd by `search_space` argument of `SimpleSampler` class. + def __init__(self, search_space: dict[str, optuna.distributions.BaseDistribution]) -> None: super().__init__(search_space) self._rng = np.random.RandomState() @@ -69,10 +67,6 @@ def sample_relative( # `search_space` argument must be identical to `search_space` argument input to `__init__` method. # This method is automatically invoked by Optuna and `SimpleBaseSampler`. - # If search space is empty, all parameter values are sampled randomly by SimpleBaseSampler. - if search_space == {}: - return {} - params = {} # type: dict[str, Any] for n, d in search_space.items(): if isinstance(d, optuna.distributions.FloatDistribution): @@ -98,7 +92,7 @@ def objective(trial: optuna.trial.Trial) -> float: ################################################################################################### # This sampler can be used in the same way as other Optuna samplers. # In the following example, we create a study and optimize it using `MySampler` class. -sampler = MySampler() +sampler = MySampler({"x": optuna.distributions.FloatDistribution(-10, 10)}) study = optuna.create_study(sampler=sampler) study.optimize(objective, n_trials=100)