diff --git a/optuna/_hypervolume/hssp.py b/optuna/_hypervolume/hssp.py index 1b63552712..00e417d95f 100644 --- a/optuna/_hypervolume/hssp.py +++ b/optuna/_hypervolume/hssp.py @@ -44,10 +44,12 @@ def _solve_hssp_on_unique_loss_vals( subset_size: int, reference_point: np.ndarray, ) -> np.ndarray: - assert not np.any(reference_point - rank_i_loss_vals <= 0) + if not np.isfinite(reference_point).all(): + return rank_i_indices[:subset_size] + diff_of_loss_vals_and_ref_point = reference_point - rank_i_loss_vals assert subset_size <= rank_i_indices.size n_objectives = reference_point.size - contribs = np.prod(reference_point - rank_i_loss_vals, axis=-1) + contribs = np.prod(diff_of_loss_vals_and_ref_point, axis=-1) selected_indices = np.zeros(subset_size, dtype=int) selected_vecs = np.empty((subset_size, n_objectives)) indices = np.arange(rank_i_loss_vals.shape[0], dtype=int) diff --git a/optuna/_hypervolume/utils.py b/optuna/_hypervolume/utils.py index bb9a4275e2..223e5aed3e 100644 --- a/optuna/_hypervolume/utils.py +++ b/optuna/_hypervolume/utils.py @@ -14,6 +14,8 @@ def _compute_2d(solution_set: np.ndarray, reference_point: np.ndarray) -> float: The reference point to compute the hypervolume. """ assert solution_set.shape[1] == 2 and reference_point.shape[0] == 2 + if not np.isfinite(reference_point).all(): + return float("inf") # Ascending order in the 1st objective, and descending order in the 2nd objective. sorted_solution_set = solution_set[np.lexsort((-solution_set[:, 1], solution_set[:, 0]))] diff --git a/optuna/_hypervolume/wfg.py b/optuna/_hypervolume/wfg.py index 206130ab18..75fb72ceb6 100644 --- a/optuna/_hypervolume/wfg.py +++ b/optuna/_hypervolume/wfg.py @@ -20,6 +20,8 @@ def __init__(self) -> None: self._reference_point: np.ndarray | None = None def _compute(self, solution_set: np.ndarray, reference_point: np.ndarray) -> float: + if not np.isfinite(reference_point).all(): + return float("inf") self._reference_point = reference_point.astype(np.float64) if self._reference_point.shape[0] == 2: return _compute_2d(solution_set, self._reference_point) diff --git a/tests/hypervolume_tests/test_hssp.py b/tests/hypervolume_tests/test_hssp.py index 454d4385e7..86be7b9f5f 100644 --- a/tests/hypervolume_tests/test_hssp.py +++ b/tests/hypervolume_tests/test_hssp.py @@ -1,4 +1,5 @@ import itertools +import math from typing import Tuple import numpy as np @@ -11,11 +12,14 @@ def _compute_hssp_truth_and_approx(test_case: np.ndarray, subset_size: int) -> T r = 1.1 * np.max(test_case, axis=0) truth = 0.0 for subset in itertools.permutations(test_case, subset_size): - truth = max(truth, optuna._hypervolume.WFG().compute(np.asarray(subset), r)) + hv = optuna._hypervolume.WFG().compute(np.asarray(subset), r) + assert not math.isnan(hv) + truth = max(truth, hv) indices = optuna._hypervolume.hssp._solve_hssp( test_case, np.arange(len(test_case)), subset_size, r ) approx = optuna._hypervolume.WFG().compute(test_case[indices], r) + assert not math.isnan(approx) return truth, approx @@ -30,24 +34,17 @@ def test_solve_hssp(dim: int) -> None: assert approx / truth > 0.6321 # 1 - 1/e -@pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_solve_hssp_infinite_loss() -> None: rng = np.random.RandomState(128) subset_size = 4 - test_case = rng.rand(9, 2) - test_case[-1].fill(float("inf")) - truth, approx = _compute_hssp_truth_and_approx(test_case, subset_size) - assert np.isinf(truth) - assert np.isinf(approx) - - test_case = rng.rand(9, 3) - test_case[-1].fill(float("inf")) - truth, approx = _compute_hssp_truth_and_approx(test_case, subset_size) - assert truth == 0 - assert np.isnan(approx) - for dim in range(2, 4): + test_case = rng.rand(9, dim) + test_case[-1].fill(float("inf")) + truth, approx = _compute_hssp_truth_and_approx(test_case, subset_size) + assert np.isinf(truth) + assert np.isinf(approx) + test_case = rng.rand(9, dim) test_case[-1].fill(-float("inf")) truth, approx = _compute_hssp_truth_and_approx(test_case, subset_size) @@ -55,7 +52,6 @@ def test_solve_hssp_infinite_loss() -> None: assert np.isinf(approx) -@pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_solve_hssp_duplicated_infinite_loss() -> None: test_case = np.array([[np.inf, 0, 0], [np.inf, 0, 0], [0, np.inf, 0], [0, 0, np.inf]]) r = np.full(3, np.inf)