
Improved duck array wrapping #9798

Open
slevang wants to merge 18 commits into main
Conversation

slevang (Contributor) commented on Nov 18, 2024:

Companion to #9776.

My attempts to use xarray-wrapped jax arrays turned up a bunch of limitations in our duck array wrapping. Jax is probably the worst case among the major duck array types out there, because it implements neither __array_function__ nor __array_ufunc__ to intercept numpy calls. Cupy, sparse, and probably others do, so calling np.func(cp.ndarray) generally works fine, but with jax it usually converts your data to numpy.

A lot of this was just grepping around for hard-coded np. cases that we can easily dispatch to duck_array_ops versions. I imagine some of these changes could be controversial, because a number of them (notably nanmean/std/var, pad, quantile, einsum, cross) aren't technically part of the standard. See #8834 (comment) for discussion about the nan-skipping aggregations.

It feels like a much better user experience though to try our best to dispatch to the correct backend, and error if the function isn't implemented, rather than blindly calling numpy. And practically, all major array backends that are feasible to wrap today (cupy, sparse, jax, cubed, arkouda, ...?) implement all of these functions.

To test, I went down the API list and exercised most functions to see if we maintain proper wrapping. Prior to the changes here, I had 28 jax failures and 9 cupy failures, while all (non-xfailed) ones now pass.

Basically everything works except the interp/missing methods which have a lot of specialized code. Also a few odds and ends like polyfit and rank.

if hasattr(x, "__array_namespace__"):
return x.__array_namespace__()
elif isinstance(x, array_type("cupy")):
# special case cupy for now
slevang (Contributor, Author):
cupy seems to have full compliance with the standard, but doesn't yet actually have __array_namespace__ on the core API. Others may be the same?

@@ -2174,9 +2174,13 @@ def _calc_idxminmax(
# we need to attach back the dim name
res.name = dim
else:
indx.data = to_numpy(indx.data)
slevang (Contributor, Author):
Not sure this is exactly what we want but got the cupy tests working.

Member:

Yeah, I would lean against this. It looks like cupy does not support indexing cupy arrays with cupy arrays?

Or is the issue that the cupy arrays are used to index xarray coordinates?

)

def sliding_window_view(array, window_shape, axis=None, **kwargs):
# TODO: some libraries (e.g. jax) don't have this, implement an alternative?
slevang (Contributor, Author):
This is one of the biggest outstanding bummers of wrapping jax arrays. There is apparently openness to adding this as an API (even though it would not offer any performance benefit in XLA). But given this is way outside the API standard, it's unclear whether it makes sense to implement a general version within xarray that doesn't rely on stride tricks.

Contributor:
You could implement a version using "summed area tables" (basically run a single accumulator and then compute differences between the window edges); or convolutions I guess.
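For the moving-sum case, the summed-area-table idea is just a cumulative sum plus a lagged difference, which every array library provides. A minimal numpy sketch (`moving_sum` is a hypothetical name, not part of this PR); the usual caveat applies that a running float accumulator can lose precision over long arrays:

```python
import numpy as np

def moving_sum(x: np.ndarray, window: int) -> np.ndarray:
    # Prepend a zero so that c[i + window] - c[i] == x[i : i + window].sum():
    # a single accumulator pass, then differences between window edges.
    c = np.concatenate([np.zeros(1, dtype=x.dtype), np.cumsum(x)])
    return c[window:] - c[:-window]
```

A mean (or, with a second accumulator of squares, a variance) follows by dividing by the window length.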

slevang (Contributor, Author):
I have something that works pretty well with this style of gather operation. But only in a jit context where XLA can work its magic. So I guess this is better left to the specific library to implement, or the user.
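For reference, a gather-style fallback can be written with plain integer fancy indexing, which jax handles well under jit. A numpy sketch of a hypothetical helper (not this PR's implementation); unlike the stride-tricks version it materializes every window, costing O(n · window) memory:

```python
import numpy as np

def sliding_window_view_gather(array, window_shape: int, axis: int = -1):
    # Build an index matrix of shape (n_windows, window_shape), then gather
    # along `axis`. Works anywhere integer take/gather is supported.
    n = array.shape[axis]
    starts = np.arange(n - window_shape + 1)
    idx = starts[:, None] + np.arange(window_shape)[None, :]
    return np.take(array, idx, axis=axis)
```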


import xarray as xr

# TODO: how to test these in CI?
slevang (Contributor, Author):
note: I just noticed xarray-array-testing

lucascolley:

see also data-apis/array-api#621 for the higher level discussion of nan reductions

Comment on lines 49 to 59
def get_array_namespace(*values):
def _get_single_namespace(x):
if hasattr(x, "__array_namespace__"):
return x.__array_namespace__()
elif isinstance(x, array_type("cupy")):
# special case cupy for now
import cupy as cp

return cp
else:
return np
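For illustration, a standalone simplification of this dispatch (the cupy special case elided; numpy is the fallback for scalars and plain ndarrays, mirroring how xarray treats numpy as the default namespace):

```python
import numpy as np

def get_array_namespace(*values):
    # Prefer an explicit Array API namespace; fall back to numpy.
    def _get_single_namespace(x):
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
        return np

    namespaces = {_get_single_namespace(x) for x in values}
    # Scalars and plain ndarrays shouldn't pin the namespace.
    namespaces.discard(np)
    if len(namespaces) > 1:
        raise TypeError(f"mixed array namespaces: {namespaces}")
    return namespaces.pop() if namespaces else np
```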


you probably want to wrap array-api-compat's array_namespace here, see https://github.com/scipy/scipy/blob/ec30b43e143ac0cb0e30129d4da0dbaa29e74c34/scipy/_lib/_array_api.py#L118-L152 for what we do in SciPy.

slevang (Contributor, Author):

Happy to go that route, I did actually try that but array-api-compat doesn't handle a bunch of things we end up passing through this (scalars, index wrappers, etc) so it would require some careful prefiltering.

The only things this package effectively wraps that don't have __array_namespace__ are cupy, dask, and torch. I tried the torch wrapper but it doesn't pass our as_compatible_data check because the wrapper object itself doesn't have __array_namespace__ or __array_function__ 😕


Support for scalars was just merged half an hour ago! data-apis/array-api-compat#147

slevang (Contributor, Author):

Looks like this would handle the array-like check well, but would require adding this as a core xarray dependency to use it in as_compatible_data. Not sure if there is any appetite for that.


In SciPy we vendor array-api-compat via a Git submodule. There's a little bit of build system bookkeeping needed, but otherwise it works well without introducing a dependency.

slevang (Contributor, Author):

I played around with this some more.

Getting an object that is compliant from the perspective of array_api_compat.is_array_api_obj to work through the hierarchy of xarray objects requires basically swapping in a bunch of more restrictive hasattr(__array_namespace__) checks for this function. Not too bad.

I was able to get things to the point that xarray can wrap torch.Tensor, which is pretty cool. But the reality is that xarray relies on a lot of functionality beyond the Array API, so this just doesn't work in any practical sense. Similarly, swapping in the more restrictive cupy.array_api for the main cupy namespace causes all kinds of things to break.

It seems torch has had little movement on supporting the standard since 2021. From xarray's perspective, cupy implements everything we need, except for the actual __array_namespace__ attribute declaring compatibility, which seems to be planned for v14.

So while this compat module is nice in theory, I don't think it's very useful for xarray. cupy, jax, and sparse are good to go without it, and we only need a single special case for cupy to fetch its __array_namespace__.

slevang (Contributor, Author) commented on Nov 20, 2024:

I looked back at cupy-xarray. Now that this duck array stuff all works pretty well, I'm wondering how people feel about adding official DataArray/Dataset methods analogous to this, but in a generalized way. Something like:

def as_array_type(self, asarray: Callable, **kwargs) -> Self:
    # e.g. ds.as_array_type(jnp.asarray, device="gpu")

def is_array_type(self, array: type) -> bool:
    # e.g. ds.is_array_type(cp.ndarray)

@slevang slevang marked this pull request as ready for review November 20, 2024 15:47
dcherian (Contributor):

> methods analogous to this, but in a generalized way.

👍 I thought I'd seen an array API method for converting between compliant array types, but I can't find it now

slevang (Contributor, Author) commented on Nov 20, 2024:

Compliant namespaces should now implement from_dlpack which is generally the recommended conversion protocol. So I suppose we could instead pass the namespace and hard code it to use that:

def to_namespace(self, xp: ModuleType, **kwargs) -> Self:
    return xp.from_dlpack(self.data)

# e.g. ds.to_namespace(cp)

But this actually doesn't work for cupy, since they're quite stringent about implicit device transfers:

TypeError: CPU arrays cannot be directly imported to CuPy. Use `cupy.array(numpy.from_dlpack(input))` instead.

Also sparse doesn't have this at all, and it wouldn't be clear whether you want a COO, DOK, etc.
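For namespaces that do support it, the DLPack hand-off itself is tiny; here is a degenerate numpy-to-numpy sketch (writing the hypothetical `to_namespace` from the snippet above as a free function). The cupy error above is exactly the limitation: DLPack refuses implicit cross-device copies, so this can't be the only conversion path.

```python
import numpy as np

def to_namespace(data, xp):
    # DLPack is the Array API's standard zero-copy exchange protocol;
    # producer and consumer must agree on the device.
    return xp.from_dlpack(data)

src = np.arange(3.0)
dst = to_namespace(src, np)  # zero-copy: dst is a view of src's buffer
```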


Comment on lines 8682 to 8687
if isinstance(v.data, array_type("cupy")):
coord_data = duck_array_ops.get_array_namespace(v.data).asarray(
coord_var.data
)
else:
coord_data = coord_var.data
Member:

It looks like this could be encapsulated into an as_like_array() helper function?

I'd like to keep the explicit array namespace stuff out of most xarray functions.
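A sketch of what such a helper might look like (hypothetical: `as_like_array` is the reviewer's proposed name, and `get_array_namespace` here is a simplified stand-in for the duck_array_ops version):

```python
import numpy as np

def get_array_namespace(x):
    # Simplified stand-in for duck_array_ops.get_array_namespace.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np

def as_like_array(data, like):
    # Coerce `data` into the same array flavor as `like`, if it isn't already.
    if type(data) is type(like):
        return data
    return get_array_namespace(like).asarray(data)
```

Call sites then stay namespace-free, e.g. `coord_data = as_like_array(coord_var.data, v.data)`.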

Comment on lines 864 to 868
if type(mask) is not type(data):
mask = duck_array_ops.get_array_namespace(data).asarray(mask)
data = duck_array_ops.where(
duck_array_ops.logical_not(mask), data, fill_value
)
Member:
this could also use as_like_array()

slevang (Contributor, Author) commented on Nov 21, 2024:

I think this is in pretty good shape now, except the question of whether to attempt any of this integration testing in CI. That could also be punted to xarray-array-testing (@keewis @TomNicholas)
