Skip to content

Commit

Permalink
Revert some changes for brivity
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Dec 10, 2024
1 parent d4c7dc8 commit d71d0c3
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions package/samplers/hebo/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,14 @@ def __init__(
self._hebo = None
self._intersection_search_space = IntersectionSearchSpace()
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
self._is_independent_sample_necessary = False
self._constant_liar = constant_liar
self._rng = np.random.default_rng(seed)

@staticmethod
def _suggest_and_transform_to_dict(
hebo: HEBO, search_space: dict[str, BaseDistribution]
) -> dict[str, Any]:
) -> dict[str, float]:
params = {}
for name, row in hebo.suggest().items():
if name not in search_space:
Expand All @@ -113,12 +114,12 @@ def _suggest_and_transform_to_dict(

def _sample_relative_define_and_run(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> dict[str, Any]:
) -> dict[str, float]:
return self._suggest_and_transform_to_dict(self._hebo, search_space)

def _sample_relative_stateless(
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> dict[str, Any]:
) -> dict[str, float]:
if self._constant_liar:
target_states = [TrialState.COMPLETE, TrialState.RUNNING]
else:
Expand All @@ -132,7 +133,10 @@ def _sample_relative_stateless(
# This sampler does not call `hebo.suggest()` here because
# Optuna needs to know search space by running the first trial in Define-by-Run.
return {}

self._is_independent_sample_necessary = True
return {}
else:
self._is_independent_sample_necessary = False
trials = [t for t in trials if set(search_space.keys()) <= set(t.params.keys())]

# Assume that the back-end HEBO implementation aims to minimize.
Expand Down Expand Up @@ -224,9 +228,7 @@ def sample_independent(
param_name: str,
param_distribution: BaseDistribution,
) -> Any:
states = (TrialState.COMPLETE, TrialState.RUNNING)
trials = study._get_trials(deepcopy=False, states=states, use_cache=True)
if any(param_name in trial.params for trial in trials):
if not self._is_independent_sample_necessary:
warnings.warn(
"`HEBOSampler` falls back to `RandomSampler` due to dynamic search space."
)
Expand Down

0 comments on commit d71d0c3

Please sign in to comment.