Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use np.real and np.imag when possible #9368

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
20 changes: 13 additions & 7 deletions xarray/namedarray/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def astype(
xp = x._data.__array_namespace__()
return x._new(data=xp.astype(x._data, dtype, copy=copy))

# np.astype doesn't exist yet:
# TODO: np.astype only exists in np 2:
return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined]


Expand Down Expand Up @@ -106,9 +106,12 @@ def imag(
<xarray.NamedArray (x: 2)> Size: 16B
array([2., 4.])
"""
xp = _get_data_namespace(x)
out = x._new(data=xp.imag(x._data))
return out
if isinstance(x._data, _arrayapi):
xp = x._data.__array_namespace__()
return x._new(data=xp.imag(x._data))

# TODO: np.imag only exists in np 2:
return x._new(data=x._data.imag) # type: ignore[attr-defined]


def real(
Expand Down Expand Up @@ -139,9 +142,12 @@ def real(
<xarray.NamedArray (x: 2)> Size: 16B
array([1., 2.])
"""
xp = _get_data_namespace(x)
out = x._new(data=xp.real(x._data))
return out
if isinstance(x._data, _arrayapi):
xp = x._data.__array_namespace__()
return x._new(data=xp.real(x._data))

# TODO: np.real only exists in np 2:
return x._new(data=x._data.real) # type: ignore[attr-defined]


# %% Manipulation functions
Expand Down
6 changes: 0 additions & 6 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,6 @@ def __array_function__(
kwargs: Mapping[str, Any],
) -> Any: ...

@property
def imag(self) -> _arrayfunction[_ShapeType_co, Any]: ...

@property
def real(self) -> _arrayfunction[_ShapeType_co, Any]: ...


@runtime_checkable
class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]):
Expand Down
13 changes: 4 additions & 9 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@
attrs: _AttrsLike = None,
):
self._data = data
self._dims = self._parse_dimensions(dims)

Check warning on line 264 in xarray/namedarray/core.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-minimum

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.

Check warning on line 264 in xarray/namedarray/core.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-minimum

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.
self._attrs = dict(attrs) if attrs else None

def __init_subclass__(cls, **kwargs: Any) -> None:
Expand Down Expand Up @@ -565,12 +565,9 @@
--------
numpy.ndarray.imag
"""
if isinstance(self._data, _arrayapi):
from xarray.namedarray._array_api import imag
from xarray.namedarray._array_api import imag

return imag(self)

return self._new(data=self._data.imag)
return imag(self)

@property
def real(
Expand All @@ -583,11 +580,9 @@
--------
numpy.ndarray.real
"""
if isinstance(self._data, _arrayapi):
from xarray.namedarray._array_api import real
from xarray.namedarray._array_api import real

return real(self)
return self._new(data=self._data.real)
return real(self)

def __dask_tokenize__(self) -> object:
# Use v.data, instead of v._data, in order to cope with the wrappers
Expand Down
Loading