From 899476502a3e81a5408583b0161f455251ba8470 Mon Sep 17 00:00:00 2001 From: Jules Date: Mon, 24 Jun 2024 17:48:21 +0800 Subject: [PATCH] WIP: top_k tests The purpose of this PR is to continue several threads of discussion regarding `top_k`. This follows roughly the specifications of `top_k` in data-apis/array-api#722, with slight modifications to the API: ```py def topk( x: array, k: int, /, axis: Optional[int] = None, *, largest: bool = True, ) -> Tuple[array, array]: ... ``` Modifications: - `mode: Literal["largest", "smallest"]` is replaced with `largest: bool` - `axis` is no longer a kw-only arg. This makes `torch.topk` slightly more compatible. The tests implemented here follows the proposed `top_k` implementation at numpy/numpy#26666. --- array_api_tests/test_searching_functions.py | 100 ++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index a12e9d52..a159c4a4 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -3,6 +3,7 @@ import pytest from hypothesis import given, note from hypothesis import strategies as st +from hypothesis.control import assume from . import _array_module as xp from . import dtype_helpers as dh @@ -203,3 +204,102 @@ def test_searchsorted(data): expected=xp.__array_namespace_info__().default_dtypes()["indexing"], ) # TODO: shapes and values testing + + +@pytest.mark.unvectorized +# TODO: Test with signed zeros and NaNs (and ignore them somehow) +@given( + x=hh.arrays( + dtype=hh.real_dtypes, + shape=hh.shapes(min_dims=1, min_side=1), + elements={"allow_nan": False}, + ), + data=st.data() +) +def test_top_k(x, data): + + if dh.is_float_dtype(x.dtype): + assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) + + axis = data.draw( + st.integers(-x.ndim, x.ndim - 1), label='axis') + largest = data.draw(st.booleans(), label='largest') + if axis is None: + k = data.draw(st.integers(1, math.prod(x.shape))) + else: + k = data.draw(st.integers(1, x.shape[axis])) + + kw = dict( + x=x, + k=k, + axis=axis, + largest=largest, + ) + + (out_values, out_indices) = xp.top_k(x, k, axis, largest=largest) + if axis is None: + x = xp.reshape(x, (-1,)) + axis = 0 + + ph.assert_dtype("top_k", in_dtype=x.dtype, out_dtype=out_values.dtype) + ph.assert_dtype( + "top_k", + in_dtype=x.dtype, + out_dtype=out_indices.dtype, + expected=dh.default_int + ) + axes, = sh.normalise_axis(axis, x.ndim) + for arr in [out_values, out_indices]: + ph.assert_shape( + "top_k", + out_shape=arr.shape, + expected=x.shape[:axes] + (k,) + x.shape[axes + 1:], + kw=kw + ) + + scalar_type = dh.get_scalar_type(x.dtype) + + for indices in sh.axes_ndindex(x.shape, (axes,)): + + # Test if the values indexed by out_indices corresponds to + # the correct top_k values. + elements = [scalar_type(x[idx]) for idx in indices] + size = len(elements) + correct_order = sorted( + range(size), + key=elements.__getitem__, + reverse=largest + ) + correct_order = correct_order[:k] + test_order = [out_indices[idx] for idx in indices[:k]] + # Sort because top_k does not necessarily return the values in + # sorted order. + test_sorted_order = sorted( + test_order, + key=elements.__getitem__, + reverse=largest + ) + + for y_o, x_o in zip(correct_order, test_sorted_order): + y_idx = indices[y_o] + x_idx = indices[x_o] + ph.assert_0d_equals( + "top_k", + x_repr=f"x[{x_idx}]", + x_val=x[x_idx], + out_repr=f"x[{y_idx}]", + out_val=x[y_idx], + kw=kw, + ) + + # Test if the values indexed by out_indices corresponds to out_values. + for y_o, x_idx in zip(test_order, indices[:k]): + y_idx = indices[y_o] + ph.assert_0d_equals( + "top_k", + x_repr=f"out_values[{x_idx}]", + x_val=scalar_type(out_values[x_idx]), + out_repr=f"x[{y_idx}]", + out_val=x[y_idx], + kw=kw + )