test_finfo fails on numpy #317

Open
ev-br opened this issue Nov 20, 2024 · 10 comments

Comments

@ev-br
Contributor

ev-br commented Nov 20, 2024

The spec seems to imply that xp.finfo(xp.float32).eps is a Python float, but numpy and jax.numpy return NumPy scalars instead:

In [1]: import torch

In [2]: torch.finfo(torch.float32).eps
Out[2]: 1.1920928955078125e-07

In [3]: type(torch.finfo(torch.float32).eps)
Out[3]: float

In [4]: import numpy as np

In [5]: type(np.finfo(np.float32).eps)
Out[5]: numpy.float32

In [6]: import jax.numpy as jnp

In [7]: type(jnp.finfo(jnp.float32).eps)
Out[7]: numpy.float32

Not hard to work around on the -compat level, even if having a wrapper for finfo feels a bit cheesy.
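For illustration only, a minimal sketch of what such a compat-level wrapper could look like. The names `FloatFinfo` and `finfo_as_python_floats` are hypothetical and are not taken from array-api-compat; the sketch assumes a NumPy version that provides `smallest_normal` (1.22 or later).

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class FloatFinfo:
    """Hypothetical container mirroring the finfo attributes, with Python scalars."""
    bits: int
    eps: float
    max: float
    min: float
    smallest_normal: float
    dtype: object


def finfo_as_python_floats(dtype):
    """Hypothetical wrapper: convert np.finfo attributes to plain Python scalars."""
    info = np.finfo(dtype)
    return FloatFinfo(
        bits=int(info.bits),
        eps=float(info.eps),  # np.float32 -> Python float
        max=float(info.max),
        min=float(info.min),
        smallest_normal=float(info.smallest_normal),
        dtype=info.dtype,
    )


wrapped = finfo_as_python_floats(np.float32)
print(type(wrapped.eps), wrapped.eps)
```

Note that converting to a Python float (a C double) would lose information for any dtype wider than float64, so this is a sketch of the workaround rather than a recommendation.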

@jakevdp
Contributor

jakevdp commented Nov 20, 2024

xref data-apis/array-api#405

JAX explicitly disables this test: https://github.com/jax-ml/jax/blob/40fc6598f96999271a3c19cfaab6f02579c003d6/tests/array_api_skips.txt#L3-L4

We've followed NumPy here and are waiting for a resolution of the above issue.

@rgommers
Member

> Not hard to work around on the -compat level, even if having a wrapper for finfo feels a bit cheesy.

There should not be a need to work around this; that's probably counterproductive. NumPy scalars duck type with float, which is fine.

> We've followed NumPy here and are waiting for a resolution of the above issue.

I re-read this issue quickly; it seems like all we need is a doc update in the standard that says "duck typing is allowed", and then making this test a bit more flexible perhaps (not sure, also okay to keep the skips). The current state in NumPy and JAX should be fine.

@ev-br the NumPy skip list is maintained at https://github.com/numpy/numpy/blob/fc7cc1ec4988ce8b73766434930c9df2661ada59/tools/ci/array-api-xfails.txt#L2, so those are the known incompatibilities - a mix of limitations in the test suite and a few missing things still to be implemented in NumPy.

@ev-br
Contributor Author

ev-br commented Nov 21, 2024

> NumPy scalars duck type with float, which is fine ... all we need is a doc update in the standard that says "duck typing is allowed"

Not sure how that would help. np.float32 is not a subclass of Python float:

In [18]: issubclass(np.float64, float)
Out[18]: True

In [19]: issubclass(np.float32, float)
Out[19]: False

Keeping the skips + closing this issue as a wontfix is probably the way to go indeed, unless NumPy implements the change (which I think makes sense, but it's up to NumPy).

@asmeurer
Member

Yes, we agreed a while ago that duck-typed floats should be OK, but the tests were never updated.
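A more flexible check could look roughly like the sketch below. `is_float_like` is a hypothetical helper, not the test suite's actual implementation; it assumes that "duck-typed float" means a non-integer object whose type defines `__float__`.

```python
import numbers

import numpy as np


def is_float_like(value):
    """Hypothetical check: accept float and float-like scalars, reject ints and strings."""
    # NumPy registers its integer scalars with numbers.Integral, so this
    # excludes both Python ints/bools and NumPy integer scalars.
    if isinstance(value, numbers.Integral):
        return False
    # Python float and NumPy floating scalars both define __float__.
    return hasattr(type(value), "__float__")


print(is_float_like(np.finfo(np.float32).eps))
print(is_float_like(1.0))
print(is_float_like(1))
print(is_float_like("1.0"))
```

Whether a check this loose is what the test suite should adopt is a separate question; the sketch only shows that np.float32 can pass without being a float subclass.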

@pearu

pearu commented Nov 21, 2024

FWIW, as a user of finfo, I have found that duck-typing of finfo instance attributes is a very useful feature.

@asmeurer
Member

@pearu can you clarify what you mean? np.float64 duck-types as float in some ways but not in others. Is the current behavior sufficient?

@rgommers
Member

I think in 99% of cases, the duck typing you need is that it works with Python operators and is accepted by NumPy functions that accept float scalars. Which is true for both float64 and float32. That isinstance, issubclass and type don't work is not all that relevant in practice I believe.
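A short sketch of the kind of duck typing described above, using only NumPy and the standard library. It shows np.float32 failing the isinstance check while still working with Python operators, `math` functions, and `float()`.

```python
import math

import numpy as np

eps = np.finfo(np.float32).eps

# The nominal type checks fail for float32 ...
print(isinstance(eps, float))

# ... but ordinary arithmetic and comparisons with Python floats work.
one_plus_eps = 1.0 + eps
print(one_plus_eps > 1.0)
print(eps < 1e-6)

# Functions that expect a float accept it through __float__.
root = math.sqrt(eps)
print(type(root))

# Explicit conversion yields the same value as a Python float.
print(float(eps))

# For comparison, np.float64 does subclass Python float.
print(isinstance(np.finfo(np.float64).eps, float))
```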

@pearu

pearu commented Nov 22, 2024

> @pearu can you clarify what you mean? np.float64 duck-types as float in some ways but not in others. Is the current behavior sufficient?

Yes. With earlier versions of numpy, I had to use numpy.float32(numpy.finfo(numpy.float32).eps), for instance, but the current behavior, where numpy.finfo(numpy.float32).eps returns a proper float instance (numpy.float32 in this case), is more practical. In fact, numpy.finfo(numpy.longdouble).max would not be possible if finfo attributes were Python floats.
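The longdouble point can be checked with a small sketch. The result is platform dependent: on platforms where long double is wider than double (for example x86-64 Linux), the maximum exceeds what a Python float can hold; where long double is the same as double, it does not.

```python
import sys

import numpy as np

ld_max = np.finfo(np.longdouble).max
print(type(ld_max))

# True only on platforms where long double has a wider range than double.
exceeds_double = bool(ld_max > sys.float_info.max)
print(exceeds_double)
```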

@asmeurer
Member

You seem to be agreeing with the point I made at data-apis/array-api#405 that the return types for finfo should not be float.

@pearu

pearu commented Nov 22, 2024

> You seem to be agreeing with the point I made at data-apis/array-api#405 that the return types for finfo should not be float.

Yes, except numpy.finfo(float).eps should be float.


5 participants