-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Array API backend #317
Array API backend #317
Conversation
cd36f46
to
a9f91ae
Compare
a9f91ae
to
bda5010
Compare
@@ -19,6 +19,7 @@ | |||
float64, | |||
) | |||
from cubed.array_api.linear_algebra_functions import matmul | |||
from cubed.backend_array_api import namespace as nxp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not simply import namespace as np
? This would minimize the diff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would, but I wanted to signal that this may not be regular NumPy. By using nxp
here, it makes it easy to search for np
to find places that are still using regular NumPy.
@@ -0,0 +1,23 @@ | |||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm excited to see this as simple place to drop in a Jax array implementation :) Great work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will be interesting to see if this works with the recent work in JAX on the Array API.
You might be able to avoid the extra dependency on array-api-compat now that numpy v1.26.0 introduced |
I vote merge this and think first about generalizing the array type in cases that don't also require changing the storage layer (e.g. cupy/sparse/pint). |
NumPy 1.22 introduced If we're worried about the extra dependency then array-api-compat can be vendored, but I'm not sure that's necessary. |
bda5010
to
7718ca4
Compare
Fixes #315
This uses https://github.com/data-apis/array-api-compat which makes NumPy conform to the array API. There are a couple of places where we use functions not in the array API (which fall back to NumPy):
take_along_axis
for arg reductions, and nan functions.