DataArray: propagate index coordinates with non-array dimensions #10116

Draft · wants to merge 13 commits into main
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -34,6 +34,9 @@ New Features
By `Benoit Bovy <https://github.com/benbovy>`_.
- Support reading to `GPU memory with Zarr <https://zarr.readthedocs.io/en/stable/user-guide/gpu.html>`_ (:pull:`10078`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray`, enabling
support for CF bounds coordinates (e.g., ``time(time)`` and ``time_bnds(time, nbnd)``) (:pull:`10116`).
By `Benoit Bovy <https://github.com/benbovy>`_.
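
To make the entry above concrete, here is a hedged usage sketch adapted from the tests added later in this PR; the `AnyIndex` class is a placeholder, since any custom xarray `Index` shared by both coordinates would do:

```python
import xarray as xr


class AnyIndex(xr.Index):
    # Placeholder custom index: the example only requires that "time" and
    # "time_bnds" share the same index object; the concrete type is an assumption.
    pass


idx = AnyIndex()
coords = xr.Coordinates(
    coords={
        "time": ("time", [1, 2]),
        "time_bnds": (("time", "nbnd"), [(0.5, 1.5), (1.5, 2.5)]),
    },
    indexes={"time": idx, "time_bnds": idx},
)

# With this change, the "nbnd" dimension of "time_bnds" no longer needs to be
# a dimension of the DataArray itself.
da = xr.DataArray([10.0, 20.0], coords=coords, dims="time")
assert "time_bnds" in da.coords and "nbnd" not in da.dims
```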

Breaking changes
~~~~~~~~~~~~~~~~
29 changes: 15 additions & 14 deletions xarray/core/coordinates.py
@@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
return self.to_dataset().identical(other.to_dataset())

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
# redirect to DatasetCoordinates._update_coords
self._data.coords._update_coords(coords, indexes)
@@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
return self._data._copy_listed(names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
variables = self._data._variables.copy()
variables.update(coords)
@@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
return self._data.dataset._copy_listed(self._names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.datatree import check_alignment

@@ -964,22 +964,23 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.dataarray import check_dataarray_coords

coords_plus_data = coords.copy()
coords_plus_data[_THIS_ARRAY] = self._data.variable
dims = calculate_dimensions(coords_plus_data)
if not set(dims) <= set(self.dims):
raise ValueError(
"cannot add coordinates with new dimensions to a DataArray"
coords_dims = set(calculate_dimensions(coords_plus_data))
obj_dims = set(self.dims)

if coords_dims > obj_dims:
# need more checks
check_dataarray_coords(
self._data.shape, Coordinates(coords, indexes), self.dims
)
self._data._coords = coords

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = dict(self._data.xindexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes
self._data._coords = coords
self._data._indexes = indexes

def _drop_coords(self, coord_names):
# should drop indexed coordinates only
58 changes: 47 additions & 11 deletions xarray/core/dataarray.py
@@ -130,22 +130,59 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)


def _check_coords_dims(shape, coords, dim):
sizes = dict(zip(dim, shape, strict=True))
for k, v in coords.items():
if any(d not in dim for d in v.dims):
def check_dataarray_coords(
shape: tuple[int, ...], coords: Coordinates, dims: tuple[Hashable, ...]
):
"""Check that ``coords`` dimension names and sizes do not conflict with
array ``shape`` and dimensions ``dims``.

If a coordinate is associated with an index, the coordinate may have
arbitrary dimension(s) as long as the index's dimensions (i.e., the union of
the dimensions of all coordinates associated with this index) intersect the
array dimensions.

If a coordinate has no index, then its dimensions must match (or be a subset
of) the array dimensions. Scalar coordinates are also allowed.

"""
indexes = coords.xindexes
skip_check_coord_name: set[Hashable] = set()
skip_check_dim_size: set[Hashable] = set()

# check dimension names
for name, var in coords.items():
if name in skip_check_coord_name:
continue
elif name in indexes:
index_dims = indexes.get_all_dims(name)
if any(d in dims for d in index_dims):
# can safely skip checking index's non-array dimensions
# and index's other coordinates since those must be all
# included in the dataarray so the index is not corrupted
skip_check_coord_name.update(indexes.get_all_coords(name))
skip_check_dim_size.update(d for d in index_dims if d not in dims)
raise_error = False
else:
raise_error = True
else:
raise_error = any(d not in dims for d in var.dims)
if raise_error:
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
f"coordinate {name} has dimensions {var.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
f"dimensions {dims}"
)
Comment on lines +152 to 174

Member:
I'm having a hard time following the logic in this loop. Could you please separate it into two separate loops:

  1. Update the sets of all coords/names to skip
  2. Raise any necessary errors

Member Author:
Probably even better: checking indexes and non-indexed coordinates in two separate loops, as you suggest in #10116 (comment).
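
For reference, a rough editorial sketch of the two-loop structure being discussed here, reusing the names from the diff above and folding in the `d in sizes` simplification suggested further down (an illustration, not code from this PR):

```python
from collections.abc import Hashable

from xarray.core.coordinates import Coordinates


def check_dataarray_coords_two_loops(
    shape: tuple[int, ...], coords: Coordinates, dims: tuple[Hashable, ...]
) -> None:
    """Same checks as in the diff above, split into the two suggested loops."""
    indexes = coords.xindexes

    # Loop 1: collect coordinates that can be skipped because they belong to
    # an index sharing at least one dimension with the array.
    skip_coord_names: set[Hashable] = set()
    for name in coords:
        if name in indexes and any(d in dims for d in indexes.get_all_dims(name)):
            skip_coord_names.update(indexes.get_all_coords(name))

    # Loop 2: raise for any remaining coordinate whose dimensions are not a
    # subset of the array dimensions.
    for name, var in coords.items():
        if name not in skip_coord_names and any(d not in dims for d in var.dims):
            raise ValueError(
                f"coordinate {name} has dimensions {var.dims}, but these "
                f"are not a subset of the DataArray dimensions {dims}"
            )

    # Dimension-size check: skipping dimensions not found on the array covers
    # the index-only dimensions collected above.
    sizes = dict(zip(dims, shape, strict=True))
    for name, var in coords.items():
        for d, s in var.sizes.items():
            if d in sizes and s != sizes[d]:
                raise ValueError(
                    f"conflicting sizes for dimension {d!r}: "
                    f"length {sizes[d]} on the data but length {s} on "
                    f"coordinate {name!r}"
                )
```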


for d, s in v.sizes.items():
if s != sizes[d]:
# check dimension sizes
sizes = dict(zip(dims, shape, strict=True))

for name, var in coords.items():
for d, s in var.sizes.items():
if d not in skip_check_dim_size and s != sizes[d]:
Member:
If I understand correctly, the only case in which a dimension is in skip_check_dim_size is when it is defined only on an index, in which case it won't appear in sizes.

In that case, I think we could simplify this by checking sizes instead of creating the skip_check_dim_size set:

Suggested change:
-    if d not in skip_check_dim_size and s != sizes[d]:
+    if d in sizes and s != sizes[d]:

raise ValueError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
f"coordinate {k!r}"
f"coordinate {name!r}"
)
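
To make the two error paths above concrete, a small hedged example; the exact wording of the printed messages is inferred from the strings in this diff:

```python
import numpy as np
import xarray as xr

# A coordinate without an index whose dimension is not on the array is rejected.
try:
    xr.DataArray(np.zeros(2), dims="x", coords={"y_coord": ("y", [1.0, 2.0, 3.0])})
except ValueError as err:
    print(err)  # coordinate y_coord has dimensions ('y',), but these are not a subset ...

# A coordinate whose shared dimension disagrees in size with the data is rejected.
try:
    xr.DataArray(np.zeros(2), dims="x", coords={"x_coord": ("x", [1.0, 2.0, 3.0])})
except ValueError as err:
    print(err)  # conflicting sizes for dimension 'x': length 2 on the data but length 3 ...
```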


@@ -212,8 +249,6 @@ def _infer_coords_and_dims(
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims_tuple)

return new_coords, dims_tuple


@@ -487,6 +522,7 @@ def __init__(

if not isinstance(coords, Coordinates):
coords = create_coords_with_default_indexes(coords)
check_dataarray_coords(data.shape, coords, dims)
indexes = dict(coords.xindexes)
coords = {k: v.copy() for k, v in coords.variables.items()}

14 changes: 12 additions & 2 deletions xarray/core/dataset.py
@@ -1210,10 +1210,20 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
needed_dims = set(variable.dims)

coords: dict[Hashable, Variable] = {}
temp_indexes = self.xindexes
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
coords[k] = self._variables[k]
if k in self._coord_names:
if (
k not in coords
and k in temp_indexes
and set(temp_indexes.get_all_dims(k)) & needed_dims
):
# add all coordinates of each index that shares at least one dimension
# with the dimensions of the extracted variable
coords.update(temp_indexes.get_all_coords(k))
Comment on lines +1217 to +1224

Member:
Could this use a separate loop after the existing loop instead? e.g.,

for k in self._indexes:
    if k in coords:
        coords.update(self.xindexes.get_all_coords(k))

Or if we allow indexes without a coordinate of the same name:

for k in self._indexes:
    if set(self.xindexes.get_all_dims(k)) & needed_dims:
        coords.update(self.xindexes.get_all_coords(k))

Ideally, I would like the logic here to be just as simple as the words describing how it works, so a comment is not necessary!

elif set(self._variables[k].dims) <= needed_dims:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))
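
A hedged sketch of the user-visible effect of the change above: extracting a variable from a Dataset now also carries index coordinates that only partially share its dimensions. The names (`AnyIndex`, `time_bnds`) are illustrative:

```python
import xarray as xr


class AnyIndex(xr.Index):
    # Illustrative custom index tying "time" and "time_bnds" together.
    pass


idx = AnyIndex()
coords = xr.Coordinates(
    coords={
        "time": ("time", [1, 2]),
        "time_bnds": (("time", "nbnd"), [(0.5, 1.5), (1.5, 2.5)]),
    },
    indexes={"time": idx, "time_bnds": idx},
)
ds = xr.Dataset({"temperature": ("time", [10.0, 20.0])}, coords=coords)

# Previously ds["temperature"] dropped "time_bnds" because of its extra "nbnd"
# dimension; with this change the whole index (time + time_bnds) is preserved.
da = ds["temperature"]
assert "time_bnds" in da.coords
```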

2 changes: 1 addition & 1 deletion xarray/core/dataset_variables.py
@@ -40,7 +40,7 @@ def __getitem__(self, key: Hashable) -> "DataArray":
raise KeyError(key)

def __repr__(self) -> str:
return formatting.data_vars_repr(self)
return formatting.data_vars_repr(self.variables)

@property
def variables(self) -> Mapping[Hashable, Variable]:
6 changes: 4 additions & 2 deletions xarray/core/formatting.py
@@ -429,7 +429,7 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None):
if col_width is None:
col_width = _calculate_col_width(coords)
return _mapping_repr(
coords,
coords.variables,
title="Coordinates",
summarizer=summarize_variable,
expand_option_name="display_expand_coords",
@@ -743,7 +743,9 @@ def dataset_repr(ds):
if unindexed_dims_str:
summary.append(unindexed_dims_str)

summary.append(data_vars_repr(ds.data_vars, col_width=col_width, max_rows=max_rows))
summary.append(
data_vars_repr(ds.data_vars.variables, col_width=col_width, max_rows=max_rows)
)

display_default_indexes = _get_boolean_with_default(
"display_default_indexes", False
11 changes: 6 additions & 5 deletions xarray/core/formatting_html.py
@@ -113,10 +113,11 @@ def summarize_variable(name, var, is_index=False, dtype=None) -> str:
)


def summarize_coords(variables) -> str:
def summarize_coords(coords) -> str:
li_items = []
for k, v in variables.items():
li_content = summarize_variable(k, v, is_index=k in variables.xindexes)
indexes = coords.xindexes
for k, v in coords.variables.items():
li_content = summarize_variable(k, v, is_index=k in indexes)
li_items.append(f"<li class='xr-var-item'>{li_content}</li>")

vars_li = "".join(li_items)
@@ -339,7 +340,7 @@ def dataset_repr(ds) -> str:
sections = [
dim_section(ds),
coord_section(ds.coords),
datavar_section(ds.data_vars),
datavar_section(ds.data_vars.variables),
index_section(_get_indexes_dict(ds.xindexes)),
attr_section(ds.attrs),
]
@@ -415,7 +416,7 @@ def datatree_node_repr(group_title: str, node: DataTree, show_inherited=False) -
sections.append(inherited_coord_section(inherited_coords))

sections += [
datavar_section(ds.data_vars),
datavar_section(ds.data_vars.variables),
attr_section(ds.attrs),
]

11 changes: 10 additions & 1 deletion xarray/core/indexes.py
@@ -1568,6 +1568,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
"""

_index_type: type[Index] | type[pd.Index]
_index_dims: dict[Hashable, Mapping[Hashable, int]]
_indexes: dict[Any, T_PandasOrXarrayIndex]
_variables: dict[Any, Variable]

@@ -1576,6 +1577,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
"__id_coord_names",
"__id_index",
"_dims",
"_index_dims",
"_index_type",
"_indexes",
"_variables",
@@ -1619,6 +1621,7 @@ def __init__(
)

self._index_type = index_type
self._index_dims = {}
self._indexes = dict(**indexes)
self._variables = dict(**variables)

@@ -1737,7 +1740,13 @@ def get_all_dims(
"""
from xarray.core.variable import calculate_dimensions

return calculate_dimensions(self.get_all_coords(key, errors=errors))
if key in self._index_dims:
return self._index_dims[key]
else:
dims = calculate_dimensions(self.get_all_coords(key, errors=errors))
if dims:
self._index_dims[key] = dims
return dims

def group_by_index(
self,
22 changes: 15 additions & 7 deletions xarray/testing/assertions.py
@@ -330,10 +330,13 @@ def _assert_indexes_invariants_checks(
k: type(v) for k, v in indexes.items()
}

index_vars = {
k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)
if check_default:
index_vars = {
k
for k, v in possible_coord_variables.items()
if isinstance(v, IndexVariable)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)

# check pandas index wrappers vs. coordinate data adapters
for k, index in indexes.items():
@@ -399,9 +402,14 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):
da.dims,
Contributor:
Are the variables in _coords different from those on .coords? I'm surprised this is still true now...

We should add one more invariant: something like "a dimension in da.coords.dims must be in da.dims or be associated with an index associated with a dimension in da.dims".

Member Author:
Good catch! The tests I added compare Coordinates objects, which thus skip this invariant check that we indeed need to update.

{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

if check_default_indexes:
assert all(
isinstance(v, IndexVariable)
for (k, v) in da._coords.items()
if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

for k, v in da._coords.items():
_assert_variable_invariants(v, k)
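
A hypothetical sketch of the extra invariant suggested in the review thread above (not part of this PR's diff); it assumes the public `xindexes` API:

```python
from xarray import DataArray


def assert_coord_dims_invariant(da: DataArray) -> None:
    # Every dimension of a coordinate must either be a dimension of the array
    # itself or belong to an index that also covers one of the array dimensions.
    array_dims = set(da.dims)
    for name, var in da.coords.variables.items():
        extra_dims = set(var.dims) - array_dims
        if not extra_dims:
            continue
        assert name in da.xindexes, (name, extra_dims)
        index_dims = set(da.xindexes.get_all_dims(name))
        assert index_dims & array_dims, (name, index_dims, array_dims)
```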

41 changes: 41 additions & 0 deletions xarray/tests/test_dataarray.py
@@ -529,6 +529,26 @@ class CustomIndex(Index): ...
# test coordinate variables copied
assert da.coords["x"] is not coords.variables["x"]

def test_constructor_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
# This test only requires that the coordinates to assign have an
# index, whatever its type.
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

actual = DataArray([1.0, 2.0], coords=coords, dims="x")

assert_identical(actual.coords, coords, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_equals_and_identical(self) -> None:
orig = DataArray(np.arange(5.0), {"a": 42}, dims="x")

@@ -1634,6 +1654,27 @@ def test_assign_coords_no_default_index(self) -> None:
assert_identical(actual.coords, coords, check_default_indexes=False)
assert "y" not in actual.xindexes

def test_assign_coords_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
# This test only requires that the coordinates to assign have an
# index, whatever its type.
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

da = DataArray([1.0, 2.0], dims="x")
actual = da.assign_coords(coords)

assert_identical(actual.coords, coords, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_coords_alignment(self) -> None:
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])