diff --git a/optuna/samplers/_tpe/sampler.py b/optuna/samplers/_tpe/sampler.py index 43e478e1e9..fbb27ac39c 100644 --- a/optuna/samplers/_tpe/sampler.py +++ b/optuna/samplers/_tpe/sampler.py @@ -587,6 +587,7 @@ def _split_trials( ) -> tuple[list[FrozenTrial], list[FrozenTrial]]: complete_trials = [] pruned_trials = [] + running_trials = [] infeasible_trials = [] for trial in trials: @@ -596,33 +597,30 @@ def _split_trials( complete_trials.append(trial) elif trial.state == TrialState.PRUNED: pruned_trials.append(trial) + elif trial.state == TrialState.RUNNING: + running_trials.append(trial) else: - assert trial.state == TrialState.RUNNING + assert False # We divide data into below and above. - if len(complete_trials) >= n_below: - below_trials = _split_complete_trials(complete_trials, study, n_below) - elif len(complete_trials) + len(pruned_trials) >= n_below: - below_pruned_trials = _split_pruned_trials( - pruned_trials, study, n_below - len(complete_trials) - ) - below_trials = complete_trials + below_pruned_trials - else: - below_infeasible_trials = _split_infeasible_trials( - infeasible_trials, n_below - len(complete_trials) - len(pruned_trials) - ) - below_trials = complete_trials + pruned_trials + below_infeasible_trials - + below_complete, above_complete = _split_complete_trials(complete_trials, study, n_below) + n_below -= len(below_complete) + below_pruned, above_pruned = _split_pruned_trials(pruned_trials, study, n_below) + n_below -= len(below_pruned) + below_infeasible, above_infeasible = _split_infeasible_trials(infeasible_trials, n_below) + + below_trials = below_complete + below_pruned + below_infeasible + above_trials = above_complete + above_pruned + above_infeasible + running_trials below_trials.sort(key=lambda trial: trial.number) - below_trial_numbers = set(trial.number for trial in below_trials) - above_trials = [trial for trial in trials if trial.number not in below_trial_numbers] + above_trials.sort(key=lambda trial: trial.number) return below_trials, above_trials def _split_complete_trials( trials: Sequence[FrozenTrial], study: Study, n_below: int -) -> list[FrozenTrial]: +) -> tuple[list[FrozenTrial], list[FrozenTrial]]: + n_below = min(n_below, len(trials)) if len(study.directions) <= 1: return _split_complete_trials_single_objective(trials, study, n_below) else: @@ -633,20 +631,21 @@ def _split_complete_trials_single_objective( trials: Sequence[FrozenTrial], study: Study, n_below: int, -) -> list[FrozenTrial]: +) -> tuple[list[FrozenTrial], list[FrozenTrial]]: if study.direction == StudyDirection.MINIMIZE: - return sorted(trials, key=lambda trial: cast(float, trial.value))[:n_below] + sorted_trials = sorted(trials, key=lambda trial: cast(float, trial.value)) else: - return sorted(trials, key=lambda trial: cast(float, trial.value), reverse=True)[:n_below] + sorted_trials = sorted(trials, key=lambda trial: cast(float, trial.value), reverse=True) + return sorted_trials[:n_below], sorted_trials[n_below:] def _split_complete_trials_multi_objective( trials: Sequence[FrozenTrial], study: Study, n_below: int, -) -> list[FrozenTrial]: +) -> tuple[list[FrozenTrial], list[FrozenTrial]]: if n_below == 0: - return [] + return [], [] lvals = np.asarray([trial.values for trial in trials]) for i, direction in enumerate(study.directions): @@ -680,7 +679,14 @@ def _split_complete_trials_multi_objective( selected_indices = _solve_hssp(rank_i_lvals, rank_i_indices, subset_size, reference_point) indices_below[last_idx:] = selected_indices - return [trials[index] for index in indices_below] + below_trials = [] + above_trials = [] + for index in range(len(trials)): + if index in indices_below: + below_trials.append(trials[index]) + else: + above_trials.append(trials[index]) + return below_trials, above_trials def _get_pruned_trial_score(trial: FrozenTrial, study: Study) -> tuple[float, float]: @@ -700,8 +706,10 @@ def _split_pruned_trials( trials: Sequence[FrozenTrial], study: Study, n_below: int, -) -> list[FrozenTrial]: - return sorted(trials, key=lambda trial: _get_pruned_trial_score(trial, study))[:n_below] +) -> tuple[list[FrozenTrial], list[FrozenTrial]]: + n_below = min(n_below, len(trials)) + sorted_trials = sorted(trials, key=lambda trial: _get_pruned_trial_score(trial, study)) + return sorted_trials[:n_below], sorted_trials[n_below:] def _get_infeasible_trial_score(trial: FrozenTrial) -> float: @@ -717,8 +725,12 @@ def _get_infeasible_trial_score(trial: FrozenTrial) -> float: return sum(v for v in constraint if v > 0) -def _split_infeasible_trials(trials: Sequence[FrozenTrial], n_below: int) -> list[FrozenTrial]: - return sorted(trials, key=_get_infeasible_trial_score)[:n_below] +def _split_infeasible_trials( + trials: Sequence[FrozenTrial], n_below: int +) -> tuple[list[FrozenTrial], list[FrozenTrial]]: + n_below = min(n_below, len(trials)) + sorted_trials = sorted(trials, key=_get_infeasible_trial_score) + return sorted_trials[:n_below], sorted_trials[n_below:] def _calculate_weights_below_for_multi_objective( diff --git a/tests/samplers_tests/tpe_tests/test_multi_objective_sampler.py b/tests/samplers_tests/tpe_tests/test_multi_objective_sampler.py index 0d3ea7679a..5c8d97fb34 100644 --- a/tests/samplers_tests/tpe_tests/test_multi_objective_sampler.py +++ b/tests/samplers_tests/tpe_tests/test_multi_objective_sampler.py @@ -324,17 +324,18 @@ def test_split_complete_trials_multi_objective(direction0: str, direction1: str) ) ) - below_trials = _tpe.sampler._split_complete_trials_multi_objective( + below_trials, above_trials = _tpe.sampler._split_complete_trials_multi_objective( study.trials, study, 2, ) assert [trial.number for trial in below_trials] == [0, 3] + assert [trial.number for trial in above_trials] == [1, 2] def test_split_complete_trials_multi_objective_empty() -> None: study = optuna.create_study(directions=("minimize", "minimize")) - _tpe.sampler._split_complete_trials_multi_objective([], study, 0) == [] + _tpe.sampler._split_complete_trials_multi_objective([], study, 0) == ([], []) def test_calculate_nondomination_rank() -> None: diff --git a/tests/samplers_tests/tpe_tests/test_sampler.py b/tests/samplers_tests/tpe_tests/test_sampler.py index 76d72d0c4a..5973f22721 100644 --- a/tests/samplers_tests/tpe_tests/test_sampler.py +++ b/tests/samplers_tests/tpe_tests/test_sampler.py @@ -832,17 +832,18 @@ def test_split_complete_trials_single_objective(direction: str) -> None: ) for n_below in range(len(study.trials) + 1): - below_trials = _tpe.sampler._split_complete_trials_single_objective( + below_trials, above_trials = _tpe.sampler._split_complete_trials_single_objective( study.trials, study, n_below, ) assert [trial.number for trial in below_trials] == list(range(n_below)) + assert [trial.number for trial in above_trials] == list(range(n_below, len(study.trials))) def test_split_complete_trials_single_objective_empty() -> None: study = optuna.create_study() - _tpe.sampler._split_complete_trials_single_objective([], study, 0) == [] + _tpe.sampler._split_complete_trials_single_objective([], study, 0) == ([], []) @pytest.mark.parametrize("direction", ["minimize", "maximize"]) @@ -869,17 +870,18 @@ def test_split_pruned_trials(direction: str) -> None: ) for n_below in range(len(study.trials) + 1): - below_trials = _tpe.sampler._split_pruned_trials( + below_trials, above_trials = _tpe.sampler._split_pruned_trials( study.trials, study, n_below, ) assert [trial.number for trial in below_trials] == list(range(n_below)) + assert [trial.number for trial in above_trials] == list(range(n_below, len(study.trials))) def test_split_pruned_trials_empty() -> None: study = optuna.create_study() - _tpe.sampler._split_pruned_trials([], study, 0) == [] + _tpe.sampler._split_pruned_trials([], study, 0) == ([], []) @pytest.mark.parametrize("direction", ["minimize", "maximize"]) @@ -898,12 +900,13 @@ def test_split_infeasible_trials(direction: str) -> None: ) for n_below in range(len(study.trials) + 1): - below_trials = _tpe.sampler._split_infeasible_trials(study.trials, n_below) + below_trials, above_trials = _tpe.sampler._split_infeasible_trials(study.trials, n_below) assert [trial.number for trial in below_trials] == list(range(n_below)) + assert [trial.number for trial in above_trials] == list(range(n_below, len(study.trials))) def test_split_infeasible_trials_empty() -> None: - _tpe.sampler._split_infeasible_trials([], 0) == [] + _tpe.sampler._split_infeasible_trials([], 0) == ([], []) def frozen_trial_factory(