From a2e33f98effc5f0f2bb69008f7be65fba6c06762 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 14 Dec 2023 01:56:05 -0800 Subject: [PATCH 1/5] Add API specifications for returning the `k` largest elements Ref: https://github.com/data-apis/array-api/issues/629 Ref: https://github.com/numpy/numpy/pull/19117 Ref: https://github.com/numpy/numpy/issues/15128 Ref: https://mail.python.org/archives/list/numpy-discussion@python.org/thread/F4P5UVTAKRJJ3OORI6UOWFSUEE5CNTSC/#PELUDW5ACUBHBNK5IVGWIWTQHBM2HXUP --- .../API_specification/searching_functions.rst | 3 + .../_draft/searching_functions.py | 132 +++++++++++++++++- 2 files changed, 133 insertions(+), 2 deletions(-) diff --git a/spec/draft/API_specification/searching_functions.rst b/spec/draft/API_specification/searching_functions.rst index 01ab4e82a..4cdde193e 100644 --- a/spec/draft/API_specification/searching_functions.rst +++ b/spec/draft/API_specification/searching_functions.rst @@ -23,4 +23,7 @@ Objects in API argmax argmin nonzero + top_k + top_k_indices + top_k_values where diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index e586a7656..e869e5b3d 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -1,7 +1,15 @@ -__all__ = ["argmax", "argmin", "nonzero", "where"] +__all__ = [ + "argmax", + "argmin", + "nonzero", + "top_k", + "top_k_values", + "top_k_indices", + "where", +] -from ._types import Optional, Tuple, array +from ._types import Optional, Literal, Tuple, array def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: @@ -87,6 +95,126 @@ def nonzero(x: array, /) -> Tuple[array, ...]: """ +def top_k( + x: array, + k: int, + /, + *, + axis: Optional[int] = None, + mode: Literal["largest", "smallest"] = "largest", +) -> Tuple[array, array]: + """ + Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. + + Parameters + ---------- + x: array + input array. Should have a real-valued data type. + k: int + number of elements to find. Must be a positive integer value. + axis: Optional[int] + axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. + mode: Literal['largest', 'smallest'] + search mode. Must be one of the following modes: + + - ``'largest'``: return the ``k`` largest elements. + - ``'smallest'``: return the ``k`` smallest elements. + + Returns + ------- + out: Tuple[array, array] + a namedtuple ``(values, indices)`` whose + + - first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. + - second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``. + + Notes + ----- + + - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements. + - The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. + - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). + """ + + +def top_k_indices( + x: array, + k: int, + /, + *, + axis: Optional[int] = None, + mode: Literal["largest", "smallest"] = "largest", +) -> array: + """ + Returns the indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. + + Parameters + ---------- + x: array + input array. Should have a real-valued data type. + k: int + number of elements to find. Must be a positive integer value. + axis: Optional[int] + axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. + mode: Literal['largest', 'smallest'] + search mode. Must be one of the following modes: + + - ``'largest'``: return the indices of the ``k`` largest elements. + - ``'smallest'``: return the indices of the ``k`` smallest elements. + + Returns + ------- + out: array + an array containing indices corresponding to the ``k`` largest (or smallest) elements of ``x``. The array must have the default array index data type. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)`` and contain the indices of a flattened ``x``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. + + Notes + ----- + + - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices. + - The order of the returned indices is left unspecified and thus implementation-dependent. Conforming implementations may return indices corresponding to sorted or unsorted values. + - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). + """ + + +def top_k_values( + x: array, + k: int, + /, + *, + axis: Optional[int] = None, + mode: Literal["largest", "smallest"] = "largest", +) -> array: + """ + Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. + + Parameters + ---------- + x: array + input array. Should have a real-valued data type. + k: int + number of elements to find. Must be a positive integer value. + axis: Optional[int] + axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. + mode: Literal['largest', 'smallest'] + search mode. Must be one of the following modes: + + - ``'largest'``: return the indices of the ``k`` largest elements. + - ``'smallest'``: return the indices of the ``k`` smallest elements. + + Returns + ------- + out: array + an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. + + Notes + ----- + + - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices. + - The order of the returned values is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. + - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). + """ + + def where(condition: array, x1: array, x2: array, /) -> array: """ Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``. From 30900ebbc4ad2c3821bcfbbf30994c64d6e14299 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 14 Dec 2023 03:13:31 -0800 Subject: [PATCH 2/5] Document default value --- src/array_api_stubs/_draft/searching_functions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index e869e5b3d..7b5b605bb 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -120,6 +120,8 @@ def top_k( - ``'largest'``: return the ``k`` largest elements. - ``'smallest'``: return the ``k`` smallest elements. + Default: ``'largest'``. + Returns ------- out: Tuple[array, array] @@ -162,6 +164,8 @@ def top_k_indices( - ``'largest'``: return the indices of the ``k`` largest elements. - ``'smallest'``: return the indices of the ``k`` smallest elements. + Default: ``'largest'``. + Returns ------- out: array @@ -201,6 +205,8 @@ def top_k_values( - ``'largest'``: return the indices of the ``k`` largest elements. - ``'smallest'``: return the indices of the ``k`` smallest elements. + Default: ``'largest'``. + Returns ------- out: array From 76873d83c57041c9a1d2e0b8b7aabb80c014cdbd Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 25 Jan 2024 03:42:09 -0800 Subject: [PATCH 3/5] Remove whitespace --- src/array_api_stubs/_draft/searching_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 58157f20d..62afd235a 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -257,7 +257,7 @@ def top_k_values( - ``'smallest'``: return the indices of the ``k`` smallest elements. Default: ``'largest'``. - + Returns ------- out: array @@ -271,7 +271,7 @@ def top_k_values( - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ - + def where(condition: array, x1: array, x2: array, /) -> array: """ Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``. From 07e62e9f5c19012ccc034946d6e12166929422f1 Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 12 Dec 2024 02:16:25 -0800 Subject: [PATCH 4/5] docs: update description Co-authored-by: ndgrigorian <46709016+ndgrigorian@users.noreply.github.com> --- src/array_api_stubs/_draft/searching_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 62afd235a..69c0f195b 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -155,7 +155,7 @@ def top_k( mode: Literal["largest", "smallest"] = "largest", ) -> Tuple[array, array]: """ - Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. + Returns the values and indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. Parameters ---------- From 96461fc89c90d1ea18720dea26516f9367d94341 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 12 Dec 2024 02:54:25 -0800 Subject: [PATCH 5/5] refactor: remove specialized APIs --- .../API_specification/searching_functions.rst | 2 - .../_draft/searching_functions.py | 84 ------------------- 2 files changed, 86 deletions(-) diff --git a/spec/draft/API_specification/searching_functions.rst b/spec/draft/API_specification/searching_functions.rst index 4cdde193e..9579c05f5 100644 --- a/spec/draft/API_specification/searching_functions.rst +++ b/spec/draft/API_specification/searching_functions.rst @@ -24,6 +24,4 @@ Objects in API argmin nonzero top_k - top_k_indices - top_k_values where diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 7b5b605bb..d546c4f28 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -3,8 +3,6 @@ "argmin", "nonzero", "top_k", - "top_k_values", - "top_k_indices", "where", ] @@ -139,88 +137,6 @@ def top_k( """ -def top_k_indices( - x: array, - k: int, - /, - *, - axis: Optional[int] = None, - mode: Literal["largest", "smallest"] = "largest", -) -> array: - """ - Returns the indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. - - Parameters - ---------- - x: array - input array. Should have a real-valued data type. - k: int - number of elements to find. Must be a positive integer value. - axis: Optional[int] - axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. - mode: Literal['largest', 'smallest'] - search mode. Must be one of the following modes: - - - ``'largest'``: return the indices of the ``k`` largest elements. - - ``'smallest'``: return the indices of the ``k`` smallest elements. - - Default: ``'largest'``. - - Returns - ------- - out: array - an array containing indices corresponding to the ``k`` largest (or smallest) elements of ``x``. The array must have the default array index data type. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)`` and contain the indices of a flattened ``x``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. - - Notes - ----- - - - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices. - - The order of the returned indices is left unspecified and thus implementation-dependent. Conforming implementations may return indices corresponding to sorted or unsorted values. - - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - """ - - -def top_k_values( - x: array, - k: int, - /, - *, - axis: Optional[int] = None, - mode: Literal["largest", "smallest"] = "largest", -) -> array: - """ - Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. - - Parameters - ---------- - x: array - input array. Should have a real-valued data type. - k: int - number of elements to find. Must be a positive integer value. - axis: Optional[int] - axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. - mode: Literal['largest', 'smallest'] - search mode. Must be one of the following modes: - - - ``'largest'``: return the indices of the ``k`` largest elements. - - ``'smallest'``: return the indices of the ``k`` smallest elements. - - Default: ``'largest'``. - - Returns - ------- - out: array - an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. - - Notes - ----- - - - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices. - - The order of the returned values is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. - - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - """ - - def where(condition: array, x1: array, x2: array, /) -> array: """ Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.