diff --git a/spec/draft/API_specification/searching_functions.rst b/spec/draft/API_specification/searching_functions.rst index 1a584f158..494b72cf7 100644 --- a/spec/draft/API_specification/searching_functions.rst +++ b/spec/draft/API_specification/searching_functions.rst @@ -25,4 +25,5 @@ Objects in API count_nonzero nonzero searchsorted + top_k where diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 4eee3173b..cef7c05be 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -1,7 +1,15 @@ -__all__ = ["argmax", "argmin", "count_nonzero", "nonzero", "searchsorted", "where"] +__all__ = [ + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "top_k", + "where", +] -from ._types import Optional, Tuple, Literal, Union, array +from ._types import Optional, Literal, Tuple, Union, array def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: @@ -168,6 +176,50 @@ def searchsorted( """ +def top_k( + x: array, + k: int, + /, + *, + axis: Optional[int] = None, + mode: Literal["largest", "smallest"] = "largest", +) -> Tuple[array, array]: + """ + Returns the values and indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. + + Parameters + ---------- + x: array + input array. Should have a real-valued data type. + k: int + number of elements to find. Must be a positive integer value. + axis: Optional[int] + axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. + mode: Literal['largest', 'smallest'] + search mode. Must be one of the following modes: + + - ``'largest'``: return the ``k`` largest elements. + - ``'smallest'``: return the ``k`` smallest elements. + + Default: ``'largest'``. + + Returns + ------- + out: Tuple[array, array] + a namedtuple ``(values, indices)`` whose + + - first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. + - second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``. + + Notes + ----- + + - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements. + - The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. + - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). + """ + + def where(condition: array, x1: array, x2: array, /) -> array: """ Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.