Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Dec 3, 2024
1 parent d185af3 commit 2c49a24
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions package/samplers/ctpe/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from optuna.logging import get_logger
from optuna.samplers import TPESampler
from optuna.samplers._tpe.parzen_estimator import _ParzenEstimator
from optuna.samplers._tpe.sampler import _split_trials
from optuna.study import Study
from optuna.study import StudyDirection
from optuna.trial import FrozenTrial
Expand Down Expand Up @@ -45,6 +44,11 @@ def __init__(
use_min_bandwidth_discrete: bool = True,
constraints_func: Callable[[FrozenTrial], Sequence[float]] | None = None,
):
if constraints_func is None:
raise ValueError(
f"{self.__class__.__name__} must take constraints_func, but got None."
)

gamma = GammaFunc(strategy=gamma_strategy, beta=gamma_beta)
weights = WeightFunc(strategy=weight_strategy)
super().__init__(
Expand Down Expand Up @@ -136,7 +140,7 @@ def _sample(
)

n_below_feasible = self._gamma(len(trials))
below_trials, above_trials = _split_trials(
below_trials, above_trials = _split_trials_for_ctpe(
study, trials, n_below_feasible, is_feasible=np.all(constraints_vals <= 0, axis=-1)
)
mpes_below.append(
Expand All @@ -145,6 +149,7 @@ def _sample(
mpes_above.append(
self._build_parzen_estimator(study, search_space, above_trials, handle_below=False)
)
quantiles.append(len(below_trials) / len(trials))

_samples_below: dict[str, list[_ParzenEstimator]] = {
param_name: [] for param_name in search_space
Expand Down

0 comments on commit 2c49a24

Please sign in to comment.