take does not specify behavior for 0-D input array #855

Closed
ndgrigorian opened this issue Nov 23, 2024 · 1 comment · Fixed by #876

@ndgrigorian

ndgrigorian commented Nov 23, 2024

The spec does not clarify how take should handle a 0-D input for the indexed argument (x).

NumPy allows this, even when axis=0:

In [1]: import numpy as np

In [2]: np.take(np.ones(()), np.zeros(2, dtype="i4"), axis=0)
Out[2]: array([1., 1.])

Of course, this does not work with Python-sequence-style indexing:

In [4]: np.ones(())[np.zeros(2, dtype="i4")]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[4], line 1
----> 1 np.ones(())[np.zeros(2, dtype="i4")]

IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed

The spec's description of take(x, indices, axis=3) as "conceptually equivalent" to x[:, :, :, indices, ...] implicitly disallows this behavior.
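A minimal sketch, assuming NumPy is installed, of the equivalence the spec describes: for an array with at least four dimensions, take(x, indices, axis=3) selects the same elements as x[:, :, :, indices, ...]. The 5-D shape and index values are arbitrary choices for illustration. The last part applies the axis=0 form of the same rule, x[indices, ...], to a 0-D array; NumPy rejects that indexing expression, consistent with the IndexError in the transcript above.

```python
import numpy as np

# Arbitrary 5-D example: axis=3 has three leading axes and one trailing axis.
x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
indices = np.array([4, 0, 2])

taken = np.take(x, indices, axis=3)
indexed = x[:, :, :, indices, ...]  # the "conceptually equivalent" form

print(taken.shape)                     # (2, 3, 4, 3, 6)
print(np.array_equal(taken, indexed))  # True

# For a 0-D array there is no axis 0, so the axis=0 form x[indices, ...]
# has nothing to index and NumPy raises IndexError.
zero_d = np.ones(())
try:
    zero_d[np.zeros(2, dtype="i4"), ...]
    zero_d_indexing_ok = True
except IndexError:
    zero_d_indexing_ok = False
print(zero_d_indexing_ok)              # False
```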

The spec's description of the output seems to agree:

out (array) – an array having the same data type as x. The output array must have the same rank (i.e., number of dimensions) as x and must have the same shape as x, except for the axis specified by axis whose size must equal the number of elements in indices.
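The output-shape rule quoted above can be written as a small function. This is a hypothetical helper (expected_take_shape is not from the spec or any library), sketched under the assumption that axis follows the usual normalization of negative values into the range [-ndim, ndim). For a 0-D x that range is empty, so every axis value, including 0, is rejected:

```python
def expected_take_shape(shape, axis, num_indices):
    """Hypothetical helper: output shape implied by the quoted spec text.

    Same rank as x, same shape as x except along `axis`, whose size
    equals the number of elements in indices.
    """
    ndim = len(shape)
    # Valid axes are -ndim <= axis < ndim; for ndim == 0 no value qualifies.
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for an array of rank {ndim}")
    axis %= ndim  # map a negative axis to its non-negative equivalent
    return shape[:axis] + (num_indices,) + shape[axis + 1:]


print(expected_take_shape((2, 3, 4), 1, 5))   # (2, 5, 4)
print(expected_take_shape((2, 3, 4), -1, 2))  # (2, 3, 2)

try:
    expected_take_shape((), 0, 2)
except ValueError as err:
    print(err)  # axis 0 is out of bounds for an array of rank 0
```

By this rule, NumPy's result in the first transcript (a 1-D array of shape (2,) returned for a 0-D x) has a different rank than x, so it cannot satisfy the "same rank" requirement.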

This wouldn't make much sense for a 0-D array: it has no axes, so there is no axis whose size could equal the number of elements in indices. A similar conclusion was drawn for cumulative_sum.

@asmeurer
Member

I agree this is very similar to the discussion at #797. There, NumPy also allows it, but as @seberg pointed out, this is really legacy NumPy behavior that is too lax about this sort of thing. Note that here, as with cumulative_sum, the issue is also that axis=0 doesn't make sense for a 0-D input, as it doesn't have any axes. With cumulative_sum, we decided to leave it undefined (https://github.com/data-apis/array-api/pull/851/files). It really should be an error, but there's no reason for the standard to take a hard line here, especially given that some libraries do allow it. That same conclusion makes sense to me here too.
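Since the standard leaves this case unspecified, a library that wants the strict behavior ("it really should be an error") could check the rank up front. The sketch below is hypothetical: strict_take is not part of NumPy or the array API standard, it simply wraps np.take, and it requires axis as a keyword argument for simplicity.

```python
import numpy as np


def strict_take(x, indices, /, *, axis):
    """Hypothetical wrapper: reject 0-D input instead of relying on
    NumPy's lenient legacy handling of axis=0 for 0-D arrays."""
    x = np.asarray(x)
    if x.ndim == 0:
        raise ValueError("take requires an input array with at least one dimension")
    return np.take(x, indices, axis=axis)


picked = strict_take(np.arange(5), np.array([3, 1]), axis=0)
print(picked)  # [3 1]

try:
    strict_take(np.ones(()), np.zeros(2, dtype="i4"), axis=0)
except ValueError as err:
    print(err)  # take requires an input array with at least one dimension
```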
