diff --git a/optuna/importance/_fanova/_tree.py b/optuna/importance/_fanova/_tree.py index 8a0f2693a5..ae22d51551 100644 --- a/optuna/importance/_fanova/_tree.py +++ b/optuna/importance/_fanova/_tree.py @@ -265,7 +265,9 @@ def _get_node_children(self, node_index: int) -> Tuple[int, int]: @lru_cache(maxsize=None) def _get_node_value(self, node_index: int) -> float: - return float(self._tree.value[node_index]) + # self._tree.value: sklearn.tree._tree.Tree.value has + # the shape (node_count, n_outputs, max_n_classes) + return float(self._tree.value[node_index].reshape(-1)[0]) @lru_cache(maxsize=None) def _get_node_split_threshold(self, node_index: int) -> float: diff --git a/optuna/terminator/improvement/gp/botorch.py b/optuna/terminator/improvement/gp/botorch.py index 0f43573fd3..c7e32f5635 100644 --- a/optuna/terminator/improvement/gp/botorch.py +++ b/optuna/terminator/improvement/gp/botorch.py @@ -83,7 +83,7 @@ def predict_mean_std( variance = posterior.variance std = variance.sqrt() - return mean.detach().numpy(), std.detach().numpy() + return mean.detach().numpy().squeeze(-1), std.detach().numpy().squeeze(-1) def _convert_trials_to_tensors(trials: list[FrozenTrial]) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/tests/importance_tests/fanova_tests/test_tree.py b/tests/importance_tests/fanova_tests/test_tree.py index d828f4c6fc..2a0e02954f 100644 --- a/tests/importance_tests/fanova_tests/test_tree.py +++ b/tests/importance_tests/fanova_tests/test_tree.py @@ -18,7 +18,8 @@ def tree() -> _FanovaTree: sklearn_tree.feature = [1, 2, -1, -1, -1] sklearn_tree.children_left = [1, 2, -1, -1, -1] sklearn_tree.children_right = [4, 3, -1, -1, -1] - sklearn_tree.value = [-1.0, -1.0, 0.1, 0.2, 0.5] + # value has the shape (node_count, n_output, max_n_classes) + sklearn_tree.value = numpy.array([[[-1.0]], [[-1.0]], [[0.1]], [[0.2]], [[0.5]]]) sklearn_tree.threshold = [0.5, 1.5, -1.0, -1.0, -1.0] search_spaces = numpy.array([[0.0, 1.0], [0.0, 1.0], [0.0, 2.0]]) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index baf57131de..021edcc6a3 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -196,9 +196,7 @@ def test_check_distribution_compatibility() -> None: ) -@pytest.mark.parametrize( - "value", (0, 1, 4, 10, 11, 1.1, "1", "1.1", "-1.0", True, False, np.ones(1), np.array([1.1])) -) +@pytest.mark.parametrize("value", (0, 1, 4, 10, 11, 1.1, "1", "1.1", "-1.0", True, False)) def test_int_internal_representation(value: Any) -> None: i = distributions.IntDistribution(low=1, high=10) @@ -231,7 +229,7 @@ def test_int_internal_representation_error(value: Any, kwargs: Dict[str, Any]) - @pytest.mark.parametrize( "value", - (1.99, 2.0, 4.5, 7, 7.1, 1, "1", "1.1", "-1.0", True, False, np.ones(1), np.array([1.1])), + (1.99, 2.0, 4.5, 7, 7.1, 1, "1", "1.1", "-1.0", True, False), ) def test_float_internal_representation(value: Any) -> None: f = distributions.FloatDistribution(low=2.0, high=7.0)