
"Stricter" option? #233

Open
awf opened this issue Jan 9, 2025 · 2 comments

Comments


awf commented Jan 9, 2025

I would love to write code that is within the array API, but which still uses the GPU.

I understand the "Avoid Restricting Behavior that is Outside the Scope of the Standard" clause in https://data-apis.org/array-api-compat/dev/special-considerations.html, but I would really prefer to replace xp.blah(..) with xp.backend.blah(..) in order to mark the places where I am going outside the standard.

I also understand that I could/should use array-api-strict in unit tests and array-api-compat in the main code, but some tests are too slow to run on CPU, and for machine-learning code it is often hard to get good test coverage. It still seems valuable to get an early indication that code I develop in, say, PyTorch has a good chance of porting to JAX.

For my use case, I would like a mode where torch/__init__.py replaces

from torch import * # noqa: F403

with

from torch import (  # noqa: F403
    abs, # Explicit list of known-compliant functions, excluding those defined in _aliases, _linalg etc
    acos,
    acosh,
    argmax,
...
    tile,
    trunc,
    uint8,
    zeros_like,
)

I understand that I might still inadvertently pick up attributes that are not in the API via __getattr__, but that is relatively easy to overcome (just avoid chaining and dot-method calls).
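To make the idea concrete, here is a rough sketch of the behaviour I have in mind; StrictNamespace, _STANDARD_NAMES and the backend attribute are all hypothetical, not anything array-api-compat provides today:

import array_api_compat

# Hypothetical sketch only: nothing below exists in array-api-compat today.
# _STANDARD_NAMES would be the explicit list of standard functions/dtypes.
_STANDARD_NAMES = {"abs", "acos", "acosh", "argmax", "astype", "max",
                   "maximum", "sum", "tanh", "float32"}

class StrictNamespace:
  def __init__(self, compat_namespace):
    self._xp = compat_namespace
    # Escape hatch: explicit, greppable access to the full namespace.
    self.backend = compat_namespace

  def __getattr__(self, name):
    # Only called when normal attribute lookup fails, i.e. for xp.<name>.
    if name in _STANDARD_NAMES:
      return getattr(self._xp, name)
    raise AttributeError(
      f"{name!r} is not in the array API standard; "
      f"use xp.backend.{name} to opt out explicitly")

def strict_array_namespace(*arrays):
  return StrictNamespace(array_api_compat.array_namespace(*arrays))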


ev-br commented Jan 10, 2025

IIUC, the classic approach would be to test on torch CPU + array-api-strict in some CI. This ensures that the array-API-compliant code will run on torch GPU and JAX GPU. Could you explain how this is limiting in your use case?
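For concreteness, a minimal sketch of that kind of test matrix (assuming pytest, array-api-strict and a CPU-only torch install; square_sum is just a stand-in for the code under test):

import pytest
import array_api_strict
import array_api_compat.torch as torch_xp

def square_sum(xp, x):
  # Stand-in for the array-API-compliant code under test.
  return xp.sum(x * x)

@pytest.mark.parametrize("xp", [array_api_strict, torch_xp])
def test_square_sum(xp):
  x = xp.asarray([1.0, 2.0, 3.0])
  assert float(square_sum(xp, x)) == pytest.approx(14.0)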


awf commented Jan 10, 2025

I'm mainly thinking of AI/ML use cases where conventional unit testing is hard: scaling a model down to a meaningful CPU-scale test can be time-consuming, perhaps as costly in time as just maintaining a port to another framework. Of course, one might argue that people who aren't willing to write proper unit tests won't care about cross-compatibility, but I don't think that's true. There is huge value in the array API even for people writing "research code" that is undergoing high churn and whose tests might lag the main codebase.

Secondly, I like to test the code I'm actually going to run: testing for accuracy on CPU when the code will run on GPU is fraught with difficulty. In such cases we often have a GPU machine available for CI, so running a separate CPU test is additional effort.

Thirdly, right now I may want to use array_api_compat.size. It's the best available solution to a problem that will probably not be solved for some time. Naturally, array-api-strict doesn't offer size (nor should it), but that means I can't write cross-framework code that tests against array-api-strict.
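For example (size is an existing array-api-compat helper; the NumPy array is just for illustration):

import numpy as np
import array_api_compat

x = np.ones((3, 4))
n = array_api_compat.size(x)   # 12; returns None if any dimension is unknown
# array-api-strict deliberately has no such helper, so code using it
# cannot be exercised against array-api-strict.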

I might invert the question: assume that I have taken my current code, in framework X:

# foo, not portable
import torch

def foo(x):
  r = (x * x).sum(1).tanh()
  r = torch.max(r, -r)
  return (x.max() > 0)*r

And I have converted it to the array API, in order that I can run it on multiple frameworks (and get the readability advantages that a single source of documentation confers):

# foo, portable
import array_api_compat

def foo_portable(x):
  xp = array_api_compat.array_namespace(x)
  r = (x * x)
  r = xp.sum(r, axis=1)
  r = xp.tanh(r)
  r = xp.maximum(r, -r)
  return xp.astype(xp.max(r) > 0, xp.float32)*r

That was a certain amount of effort, which I achieved with some side-by-side running using array-api-strict.

Now, while experimenting, I decide to replace tanh with exp2:

def foo_v2(x):
  xp = array_api_compat.array_namespace(x)
  r = (x * x)
  r = xp.sum(r, axis=1)
  r = xp.exp2(r)
  r = xp.maximum(r, -r)
  return xp.astype(xp.max(r) > 0, xp.float32)*r

In this proposed array-api-compat "stricter" mode, I would get an error telling me that xp.exp2 is not defined, which is helpful: I simply replace it with xp.backend.exp2 and proceed. I still have mostly-compliant code, which happens to work on PyTorch and JAX, and when I want to port it to the "proper" array API, I can easily see the places where I have used backend.
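Concretely, under this hypothetical mode foo_v2 would end up looking something like this (xp.backend is the proposed escape hatch, not an existing attribute):

def foo_v2_stricter(x):
  xp = array_api_compat.array_namespace(x)  # imagined "stricter" namespace
  r = (x * x)
  r = xp.sum(r, axis=1)
  r = xp.backend.exp2(r)   # explicitly marked as outside the standard
  r = xp.maximum(r, -r)
  return xp.astype(xp.max(r) > 0, xp.float32)*r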

Given the effort already expended in porting foo, the additional effort of adding .backend to a few calls is tiny. It means I can continue working until it's a good time to start polishing my pull request, at which point I can figure out how best to make exp2 multi-framework.
