Skip to content

Commit

Permalink
Return below and above in each split functions
Browse files Browse the repository at this point in the history
  • Loading branch information
not522 committed Sep 29, 2023
1 parent 25edf5c commit ce93d40
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 35 deletions.
66 changes: 39 additions & 27 deletions optuna/samplers/_tpe/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def _split_trials(
) -> tuple[list[FrozenTrial], list[FrozenTrial]]:
complete_trials = []
pruned_trials = []
running_trials = []
infeasible_trials = []

for trial in trials:
Expand All @@ -596,33 +597,30 @@ def _split_trials(
complete_trials.append(trial)
elif trial.state == TrialState.PRUNED:
pruned_trials.append(trial)
elif trial.state == TrialState.RUNNING:
running_trials.append(trial)
else:
assert trial.state == TrialState.RUNNING
assert False

# We divide data into below and above.
if len(complete_trials) >= n_below:
below_trials = _split_complete_trials(complete_trials, study, n_below)
elif len(complete_trials) + len(pruned_trials) >= n_below:
below_pruned_trials = _split_pruned_trials(
pruned_trials, study, n_below - len(complete_trials)
)
below_trials = complete_trials + below_pruned_trials
else:
below_infeasible_trials = _split_infeasible_trials(
infeasible_trials, n_below - len(complete_trials) - len(pruned_trials)
)
below_trials = complete_trials + pruned_trials + below_infeasible_trials

below_complete, above_complete = _split_complete_trials(complete_trials, study, n_below)
n_below -= len(below_complete)
below_pruned, above_pruned = _split_pruned_trials(pruned_trials, study, n_below)
n_below -= len(below_pruned)
below_infeasible, above_infeasible = _split_infeasible_trials(infeasible_trials, n_below)

below_trials = below_complete + below_pruned + below_infeasible
above_trials = above_complete + above_pruned + above_infeasible + running_trials
below_trials.sort(key=lambda trial: trial.number)
below_trial_numbers = set(trial.number for trial in below_trials)
above_trials = [trial for trial in trials if trial.number not in below_trial_numbers]
above_trials.sort(key=lambda trial: trial.number)

return below_trials, above_trials


def _split_complete_trials(
trials: Sequence[FrozenTrial], study: Study, n_below: int
) -> list[FrozenTrial]:
) -> tuple[list[FrozenTrial], list[FrozenTrial]]:
n_below = min(n_below, len(trials))
if len(study.directions) <= 1:
return _split_complete_trials_single_objective(trials, study, n_below)
else:
Expand All @@ -633,20 +631,21 @@ def _split_complete_trials_single_objective(
trials: Sequence[FrozenTrial],
study: Study,
n_below: int,
) -> list[FrozenTrial]:
) -> tuple[list[FrozenTrial], list[FrozenTrial]]:
if study.direction == StudyDirection.MINIMIZE:
return sorted(trials, key=lambda trial: cast(float, trial.value))[:n_below]
sorted_trials = sorted(trials, key=lambda trial: cast(float, trial.value))
else:
return sorted(trials, key=lambda trial: cast(float, trial.value), reverse=True)[:n_below]
sorted_trials = sorted(trials, key=lambda trial: cast(float, trial.value), reverse=True)
return sorted_trials[:n_below], sorted_trials[n_below:]


def _split_complete_trials_multi_objective(
trials: Sequence[FrozenTrial],
study: Study,
n_below: int,
) -> list[FrozenTrial]:
) -> tuple[list[FrozenTrial], list[FrozenTrial]]:
if n_below == 0:
return []
return [], []

lvals = np.asarray([trial.values for trial in trials])
for i, direction in enumerate(study.directions):
Expand Down Expand Up @@ -680,7 +679,14 @@ def _split_complete_trials_multi_objective(
selected_indices = _solve_hssp(rank_i_lvals, rank_i_indices, subset_size, reference_point)
indices_below[last_idx:] = selected_indices

return [trials[index] for index in indices_below]
below_trials = []
above_trials = []
for index in range(len(trials)):
if index in indices_below:
below_trials.append(trials[index])
else:
above_trials.append(trials[index])
return below_trials, above_trials


def _get_pruned_trial_score(trial: FrozenTrial, study: Study) -> tuple[float, float]:
Expand All @@ -700,8 +706,10 @@ def _split_pruned_trials(
trials: Sequence[FrozenTrial],
study: Study,
n_below: int,
) -> list[FrozenTrial]:
return sorted(trials, key=lambda trial: _get_pruned_trial_score(trial, study))[:n_below]
) -> tuple[list[FrozenTrial], list[FrozenTrial]]:
n_below = min(n_below, len(trials))
sorted_trials = sorted(trials, key=lambda trial: _get_pruned_trial_score(trial, study))
return sorted_trials[:n_below], sorted_trials[n_below:]


def _get_infeasible_trial_score(trial: FrozenTrial) -> float:
Expand All @@ -717,8 +725,12 @@ def _get_infeasible_trial_score(trial: FrozenTrial) -> float:
return sum(v for v in constraint if v > 0)


def _split_infeasible_trials(trials: Sequence[FrozenTrial], n_below: int) -> list[FrozenTrial]:
return sorted(trials, key=_get_infeasible_trial_score)[:n_below]
def _split_infeasible_trials(
trials: Sequence[FrozenTrial], n_below: int
) -> tuple[list[FrozenTrial], list[FrozenTrial]]:
n_below = min(n_below, len(trials))
sorted_trials = sorted(trials, key=_get_infeasible_trial_score)
return sorted_trials[:n_below], sorted_trials[n_below:]


def _calculate_weights_below_for_multi_objective(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,17 +324,18 @@ def test_split_complete_trials_multi_objective(direction0: str, direction1: str)
)
)

below_trials = _tpe.sampler._split_complete_trials_multi_objective(
below_trials, above_trials = _tpe.sampler._split_complete_trials_multi_objective(
study.trials,
study,
2,
)
assert [trial.number for trial in below_trials] == [0, 3]
assert [trial.number for trial in above_trials] == [1, 2]


def test_split_complete_trials_multi_objective_empty() -> None:
study = optuna.create_study(directions=("minimize", "minimize"))
_tpe.sampler._split_complete_trials_multi_objective([], study, 0) == []
_tpe.sampler._split_complete_trials_multi_objective([], study, 0) == ([], [])


def test_calculate_nondomination_rank() -> None:
Expand Down
15 changes: 9 additions & 6 deletions tests/samplers_tests/tpe_tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,17 +832,18 @@ def test_split_complete_trials_single_objective(direction: str) -> None:
)

for n_below in range(len(study.trials) + 1):
below_trials = _tpe.sampler._split_complete_trials_single_objective(
below_trials, above_trials = _tpe.sampler._split_complete_trials_single_objective(
study.trials,
study,
n_below,
)
assert [trial.number for trial in below_trials] == list(range(n_below))
assert [trial.number for trial in above_trials] == list(range(n_below, len(study.trials)))


def test_split_complete_trials_single_objective_empty() -> None:
study = optuna.create_study()
_tpe.sampler._split_complete_trials_single_objective([], study, 0) == []
_tpe.sampler._split_complete_trials_single_objective([], study, 0) == ([], [])


@pytest.mark.parametrize("direction", ["minimize", "maximize"])
Expand All @@ -869,17 +870,18 @@ def test_split_pruned_trials(direction: str) -> None:
)

for n_below in range(len(study.trials) + 1):
below_trials = _tpe.sampler._split_pruned_trials(
below_trials, above_trials = _tpe.sampler._split_pruned_trials(
study.trials,
study,
n_below,
)
assert [trial.number for trial in below_trials] == list(range(n_below))
assert [trial.number for trial in above_trials] == list(range(n_below, len(study.trials)))


def test_split_pruned_trials_empty() -> None:
study = optuna.create_study()
_tpe.sampler._split_pruned_trials([], study, 0) == []
_tpe.sampler._split_pruned_trials([], study, 0) == ([], [])


@pytest.mark.parametrize("direction", ["minimize", "maximize"])
Expand All @@ -898,12 +900,13 @@ def test_split_infeasible_trials(direction: str) -> None:
)

for n_below in range(len(study.trials) + 1):
below_trials = _tpe.sampler._split_infeasible_trials(study.trials, n_below)
below_trials, above_trials = _tpe.sampler._split_infeasible_trials(study.trials, n_below)
assert [trial.number for trial in below_trials] == list(range(n_below))
assert [trial.number for trial in above_trials] == list(range(n_below, len(study.trials)))


def test_split_infeasible_trials_empty() -> None:
_tpe.sampler._split_infeasible_trials([], 0) == []
_tpe.sampler._split_infeasible_trials([], 0) == ([], [])


def frozen_trial_factory(
Expand Down

0 comments on commit ce93d40

Please sign in to comment.