Improved duck array wrapping #9798
Conversation
xarray/core/array_api_compat.py

if hasattr(x, "__array_namespace__"):
    return x.__array_namespace__()
elif isinstance(x, array_type("cupy")):
    # special case cupy for now
cupy seems to have full compliance with the standard, but doesn't yet actually have __array_namespace__ on the core API. Others may be the same?
xarray/core/computation.py

@@ -2174,9 +2174,13 @@ def _calc_idxminmax(
        # we need to attach back the dim name
        res.name = dim
    else:
        indx.data = to_numpy(indx.data)
Not sure this is exactly what we want, but it got the cupy tests working.
Yeah, I would lean against this. It looks like cupy does not support indexing cupy arrays with cupy arrays?
Or is the issue that the cupy arrays are used to index xarray coordinates?
)


def sliding_window_view(array, window_shape, axis=None, **kwargs):
    # TODO: some libraries (e.g. jax) don't have this, implement an alternative?
This is one of the biggest outstanding bummers of wrapping jax arrays. There is apparently openness to adding this as an API (even though it would not offer any performance benefit in XLA). But given this is way outside the API standard, I'm not sure whether it makes sense to implement a general version within xarray that doesn't rely on stride tricks.
You could implement a version using "summed area tables" (basically run a single accumulator and then compute differences between the window edges); or convolutions I guess.
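For concreteness, here is a minimal sketch of the summed-area-table idea for a windowed sum, written against plain numpy (a hypothetical helper, not code from this PR; it only covers reductions expressible through a cumulative sum, not a general sliding_window_view replacement):

import numpy as np

def sliding_window_sum(a, window, axis=-1):
    # One cumulative pass, then differences between the window edges.
    a = np.moveaxis(np.asarray(a), axis, -1)
    zeros = np.zeros(a.shape[:-1] + (1,), dtype=a.dtype)
    csum = np.concatenate([zeros, np.cumsum(a, axis=-1)], axis=-1)
    out = csum[..., window:] - csum[..., :-window]
    return np.moveaxis(out, -1, axis)

For example, sliding_window_sum(np.arange(6.0), 3) returns [3., 6., 9., 12.].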
I have something that works pretty well with this style of gather operation, but only in a jit context where XLA can work its magic. So I guess this is better left to the specific library to implement, or the user.
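For reference, a rough sketch of the gather-style approach mentioned here, assuming jax is installed (a hypothetical helper, windows along the last axis only; the window size must be static so shapes are known to XLA):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="window")
def sliding_windows(a, window):
    # Build all window start offsets, then gather each window with fancy
    # indexing; under jit, XLA lowers this to a single gather.
    starts = jnp.arange(a.shape[-1] - window + 1)
    idx = starts[:, None] + jnp.arange(window)[None, :]
    return a[..., idx]  # shape (..., n_windows, window)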
import xarray as xr

# TODO: how to test these in CI?
Note: I just noticed xarray-array-testing. See also data-apis/array-api#621 for the higher-level discussion.
xarray/core/array_api_compat.py

def get_array_namespace(*values):
    def _get_single_namespace(x):
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
        elif isinstance(x, array_type("cupy")):
            # special case cupy for now
            import cupy as cp

            return cp
        else:
            return np
You probably want to wrap array-api-compat's array_namespace here; see https://github.com/scipy/scipy/blob/ec30b43e143ac0cb0e30129d4da0dbaa29e74c34/scipy/_lib/_array_api.py#L118-L152 for what we do in SciPy.
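For illustration, a rough sketch of that suggestion, assuming array-api-compat is importable (the prefiltering here is only indicative, not SciPy's or xarray's actual logic):

import numpy as np

def get_array_namespace(*values):
    # Drop scalars, index wrappers, and other non-array values that would
    # make array_api_compat.array_namespace raise, then defer to it.
    arrays = [v for v in values if hasattr(v, "shape") and hasattr(v, "dtype")]
    if not arrays:
        return np
    try:
        from array_api_compat import array_namespace
    except ImportError:
        return np
    return array_namespace(*arrays)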
Happy to go that route. I did actually try that, but array-api-compat doesn't handle a bunch of things we end up passing through this (scalars, index wrappers, etc.), so it would require some careful prefiltering.

The only things this package effectively wraps that don't have __array_namespace__ are cupy, dask, and torch. I tried the torch wrapper, but it doesn't pass our as_compatible_data check because the wrapper object itself doesn't have __array_namespace__ or __array_function__ 😕
Support for scalars was just merged half an hour ago! data-apis/array-api-compat#147
Looks like this would handle the array-like check well, but it would require adding this as a core xarray dependency to use it in as_compatible_data. Not sure if there is any appetite for that.
In SciPy we vendor array-api-compat via a Git submodule. There's a little bit of build system bookkeeping needed, but otherwise it works well without introducing a dependency.
I played around with this some more.

Getting an object that is compliant from the perspective of array_api_compat.is_array_api_obj to work through the hierarchy of xarray objects basically requires swapping this function in for a bunch of more restrictive hasattr(__array_namespace__) checks. Not too bad.

I was able to get things to the point that xarray can wrap torch.Tensor, which is pretty cool. But the reality is that xarray relies on a lot of functionality beyond the Array API, so this just doesn't work in any practical sense. Similarly, swapping in the more restrictive cupy.array_api for the main cupy namespace causes all kinds of things to break.

It seems torch has had little movement on supporting the standard since 2021. From xarray's perspective, cupy implements everything we need except the actual __array_namespace__ attribute declaring compatibility, which seems to be planned for v14.

So while this compat module is nice in theory, I don't think it's very useful for xarray. cupy, jax, and sparse are good to go without it, and we only need a single special case for cupy to fetch its __array_namespace__.
I looked back at:

def as_array_type(self, asarray: callable, **kwargs) -> Self:
    # e.g. ds.to_array_type(jnp.asarray, device="gpu")

def is_array_type(self, array: type) -> bool:
    # e.g. ds.is_array_type(cp.ndarray)
👍 I thought I'd seen an array API method for converting between compliant array types, but I can't find it now.
Compliant namespaces should now implement from_dlpack, so something like:

def to_namespace(self, xp: ModuleType, **kwargs) -> Self:
    xp.from_dlpack(self.data)
    # e.g. ds.to_namespace(cp)

But this actually doesn't work for cupy, since they're quite stringent about implicit device transfers. Also sparse doesn't have this at all, and it wouldn't be clear whether you want a …
xarray/core/dataset.py

if isinstance(v.data, array_type("cupy")):
    coord_data = duck_array_ops.get_array_namespace(v.data).asarray(
        coord_var.data
    )
else:
    coord_data = coord_var.data
It looks like this could be encapsulated into an as_like_array() helper function? I'd like to keep the explicit array namespace stuff out of most xarray functions.
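Something along these lines is presumably what is meant (a sketch only; the as_like_array name and signature come from this suggestion, not existing xarray API):

from xarray.core import duck_array_ops

def as_like_array(data, other):
    # Coerce `data` into the same array library as `other`, e.g. numpy
    # coordinate data into cupy when `other` is a cupy array; no-op when
    # the types already match.
    if type(data) is type(other):
        return data
    return duck_array_ops.get_array_namespace(other).asarray(data)

With that, the dataset.py hunk above could reduce to coord_data = as_like_array(coord_var.data, v.data), and the variable.py hunk below to mask = as_like_array(mask, data).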
xarray/core/variable.py

if type(mask) is not type(data):
    mask = duck_array_ops.get_array_namespace(data).asarray(mask)
data = duck_array_ops.where(
    duck_array_ops.logical_not(mask), data, fill_value
)
This could also use as_like_array().

I think this is in pretty good shape now, except for the question of whether to attempt any of this integration testing in CI. That could also be punted to xarray-array-testing (@keewis @TomNicholas).
Companion to #9776.

My attempts to use xarray-wrapped jax arrays turned up a bunch of limitations of our duck array wrapping. Jax is probably the worst case of the major duck array types out there, because it doesn't implement either __array_function__ or __array_ufunc__ to intercept numpy calls. Cupy, sparse, and probably others do, so calling np.func(cp.ndarray) generally works fine, but with jax this usually converts your data to numpy.

A lot of this was just grepping around for hard-coded np. cases that we can easily dispatch to duck_array_ops versions. I imagine some of these changes could be controversial, because a number of them (notably nanmean/std/var, pad, quantile, einsum, cross) aren't technically part of the standard. See #8834 (comment) for discussion about the nan-skipping aggregations.

It feels like a much better user experience, though, to try our best to dispatch to the correct backend and error if the function isn't implemented, rather than blindly calling numpy. And practically, all major array backends that are feasible to wrap today (cupy, sparse, jax, cubed, arkouda, ...?) implement all of these functions.

To test, I just ran down the API list and ran most functions to see if we maintain proper wrapping. Prior to the changes here, I had 28 jax failures and 9 cupy failures, while all (non-xfailed) ones now pass.

Basically everything works except the interp/missing methods, which have a lot of specialized code, plus a few odds and ends like polyfit and rank.
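To illustrate the dispatch pattern described above, a minimal hedged sketch (the helper and error message are illustrative, not the exact code in this PR): look up the wrapped array's namespace and use its implementation, erroring instead of silently falling back to numpy.

import numpy as np

def _pad(array, pad_width, mode="constant", **kwargs):
    # Dispatch to the array's own namespace if it advertises one, otherwise
    # fall back to numpy; raise when the backend lacks the function.
    xp = array.__array_namespace__() if hasattr(array, "__array_namespace__") else np
    if not hasattr(xp, "pad"):
        raise NotImplementedError(f"{xp.__name__} does not implement pad")
    return xp.pad(array, pad_width, mode=mode, **kwargs)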