Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API specifications for returning the k largest elements #722

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions spec/draft/API_specification/searching_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ Objects in API
argmin
nonzero
searchsorted
top_k
top_k_indices
top_k_values
where
139 changes: 137 additions & 2 deletions src/array_api_stubs/_draft/searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
__all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"]
__all__ = [
"argmax",
"argmin",
"nonzero",
"searchsorted",
"top_k",
"top_k_values",
"top_k_indices",
"where",
]


from ._types import Optional, Tuple, Literal, array
from ._types import Optional, Literal, Tuple, array


def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
Expand Down Expand Up @@ -137,6 +146,132 @@ def searchsorted(
"""


def top_k(
x: array,
k: int,
/,
*,
axis: Optional[int] = None,
mode: Literal["largest", "smallest"] = "largest",
) -> Tuple[array, array]:
"""
Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.
Returns the values and indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.

Otherwise it would be identical to the description of top_k_values.

Parameters
----------
x: array
input array. Should have a real-valued data type.
k: int
number of elements to find. Must be a positive integer value.
axis: Optional[int]
axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``.
mode: Literal['largest', 'smallest']
search mode. Must be one of the following modes:

- ``'largest'``: return the ``k`` largest elements.
- ``'smallest'``: return the ``k`` smallest elements.

Default: ``'largest'``.

Returns
-------
out: Tuple[array, array]
a namedtuple ``(values, indices)`` whose

- first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``.
- second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``.

Notes
-----

- If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements.
- The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values.
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).
"""


def top_k_indices(
x: array,
k: int,
/,
*,
axis: Optional[int] = None,
mode: Literal["largest", "smallest"] = "largest",
) -> array:
"""
Returns the indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.

Parameters
----------
x: array
input array. Should have a real-valued data type.
k: int
number of elements to find. Must be a positive integer value.
axis: Optional[int]
axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``.
mode: Literal['largest', 'smallest']
search mode. Must be one of the following modes:

- ``'largest'``: return the indices of the ``k`` largest elements.
- ``'smallest'``: return the indices of the ``k`` smallest elements.

Default: ``'largest'``.

Returns
-------
out: array
an array containing indices corresponding to the ``k`` largest (or smallest) elements of ``x``. The array must have the default array index data type. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)`` and contain the indices of a flattened ``x``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``.

Notes
-----

- If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices.
- The order of the returned indices is left unspecified and thus implementation-dependent. Conforming implementations may return indices corresponding to sorted or unsorted values.
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).
"""


def top_k_values(
x: array,
k: int,
/,
*,
axis: Optional[int] = None,
mode: Literal["largest", "smallest"] = "largest",
) -> array:
"""
Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.

Parameters
----------
x: array
input array. Should have a real-valued data type.
k: int
number of elements to find. Must be a positive integer value.
axis: Optional[int]
axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``.
mode: Literal['largest', 'smallest']
search mode. Must be one of the following modes:

- ``'largest'``: return the indices of the ``k`` largest elements.
- ``'smallest'``: return the indices of the ``k`` smallest elements.

Default: ``'largest'``.

Returns
-------
out: array
an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``.

Notes
-----

- If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices.
- The order of the returned values is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values.
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).
"""


def where(condition: array, x1: array, x2: array, /) -> array:
"""
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
Expand Down
Loading