"Stricter" option? #233
IIUC the classic approach would be to test on torch CPU + array-api-strict in some CI. This will ensure that the array-API-compliant code runs on torch GPU and JAX GPU. Could you explain how this is limiting in your use case?
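The "classic approach" described above can be sketched as a small parametrised check: write the function once against the array API, then run the same assertions under every namespace that happens to be importable locally. This is only an illustration, not code from the thread; `norms` is a made-up portable function, and only numpy is assumed installed (`array_api_strict` joins the list when available).

```python
# Sketch of the "test once against several namespaces" pattern.
# `norms` is a hypothetical portable function, not from the issue.
import numpy as np


def norms(xp, x):
    # Row-wise Euclidean norms, using only standard array API calls.
    return xp.sqrt(xp.sum(x * x, axis=1))


# Collect whichever array-API namespaces are importable locally.
namespaces = [np]
try:
    import array_api_strict
    namespaces.append(array_api_strict)
except ImportError:
    pass

# The same assertions run unchanged under every namespace in the list.
for xp in namespaces:
    x = xp.asarray([[3.0, 4.0], [0.0, 0.0]])
    out = norms(xp, x)
    assert float(out[0]) == 5.0
    assert float(out[1]) == 0.0
```

In a real CI setup the list would be fixed (e.g. torch CPU plus array-api-strict) rather than discovered at import time.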
I'm mainly thinking of AI/ML use cases where conventional unit testing is hard: scaling down to a meaningful CPU-scale test can be time-consuming, perhaps as costly in time as just maintaining the port to another framework. Of course, one might argue that people who aren't willing to write proper unit tests will not care about cross-compatibility, but I don't think that's true. There is huge value in …

Secondly, I like to test the code I'm actually going to run: testing for accuracy on CPU when the code will run on GPU is fraught with difficulty. In such cases, we often have a GPU machine available for CI, so running a separate CPU test is an additional effort.

Thirdly, right now, I may want to use …

I might invert the question: assume that I have taken my current code, in framework X:

```python
# foo, not portable
def foo(x):
    r = (x * x).sum(1).tanh()
    r = torch.max(r, -r)
    return (x.max() > 0) * r
```

And I have converted it to the array API, so that I can run it on multiple frameworks (and get the readability advantages that a single source of documentation confers):

```python
# foo, portable
def foo_portable(x):
    xp = array_api_compat.array_namespace(x)
    r = x * x
    r = xp.sum(r, axis=1)
    r = xp.tanh(r)
    r = xp.maximum(r, -r)
    return xp.astype(xp.max(r) > 0, xp.float32) * r
```

That was a certain amount of effort, which I achieved with some side-by-side running using array-api-strict. Now, while experimenting, I decide to replace `tanh` with `exp2`:

```python
def foo_v2(x):
    xp = array_api_compat.array_namespace(x)
    r = x * x
    r = xp.sum(r, axis=1)
    r = xp.exp2(r)
    r = xp.maximum(r, -r)
    return xp.astype(xp.max(r) > 0, xp.float32) * r
```

In this proposed array-api-compat "stricter" mode, I will get an error telling me that `exp2` is not part of the array API standard. Given the effort already expended in porting …
I would love to write code that is within the array API, but which still uses the GPU.

I get the "Avoid Restricting Behavior that is Outside the Scope of the Standard" clause in https://data-apis.org/array-api-compat/dev/special-considerations.html, but I would really prefer to replace `xp.blah(..)` with `xp.back.blah(..)` in order to mark places where I am going outside the standard.

I also understand that I could/should use `array-api-strict` in unit tests and then use `array-api-compat` in the main code, but some tests are too slow to run on CPU, and for machine learning code it's often quite hard to get good test coverage. It still seems valuable to get an early indication that code I develop in PyTorch, say, has a good chance of porting to JAX.

For my use case, I would like a mode where `torch/__init__.py` replaces … with …

I understand that I might still inadvertently use `getattr` methods which are not in the API, but that is relatively easy to overcome (just avoid chaining and dot methods).