From f93b45ad2803b3f0afae093d67c646842115c32e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 12 Mar 2025 12:59:08 +0100 Subject: [PATCH 01/11] DataArray constructor: update coords dims check --- xarray/core/dataarray.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4324a4587b3..d6c72b6af7e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -130,18 +130,30 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset) -def _check_coords_dims(shape, coords, dim): - sizes = dict(zip(dim, shape, strict=True)) +def _check_coords_dims( + shape: tuple[int, ...], coords: Coordinates, dims: tuple[Hashable, ...] +): + sizes = dict(zip(dims, shape, strict=True)) + extra_index_dims = set() + for k, v in coords.items(): - if any(d not in dim for d in v.dims): + if any(d not in dims for d in v.dims): + # allow any coordinate associated with an index that shares at least + # one of dataarray's dimensions + indexes = coords.xindexes + if k in indexes: + index_dims = indexes.get_all_dims(k) + if any(d in dims for d in index_dims): + extra_index_dims.update(d for d in v.dims if d not in dims) + continue raise ValueError( f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " - f"dimensions {dim}" + f"dimensions {dims}" ) for d, s in v.sizes.items(): - if s != sizes[d]: + if d not in extra_index_dims and s != sizes[d]: raise ValueError( f"conflicting sizes for dimension {d!r}: " f"length {sizes[d]} on the data but length {s} on " @@ -212,8 +224,6 @@ def _infer_coords_and_dims( var.dims = (dim,) new_coords[dim] = var.to_index_variable() - _check_coords_dims(shape, new_coords, dims_tuple) - return new_coords, dims_tuple @@ -487,6 +497,7 @@ def __init__( if not isinstance(coords, Coordinates): coords = create_coords_with_default_indexes(coords) + _check_coords_dims(data.shape, coords, dims) indexes = 
dict(coords.xindexes) coords = {k: v.copy() for k, v in coords.variables.items()} From 01fc5fdc454575ce3a97404a1958220b13f650bb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 12 Mar 2025 13:00:22 +0100 Subject: [PATCH 02/11] get datarray from dataset: update coords selection --- xarray/core/dataset.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 79a2dde3444..93675f9293f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1210,10 +1210,20 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: needed_dims = set(variable.dims) coords: dict[Hashable, Variable] = {} + indexes = self.xindexes # preserve ordering for k in self._variables: - if k in self._coord_names and set(self._variables[k].dims) <= needed_dims: - coords[k] = self._variables[k] + if k in self._coord_names: + if ( + k not in coords + and k in indexes + and set(indexes.get_all_dims(k)) & needed_dims + ): + # add all coordinates of each index that shares at least one dimension + # with the dimensions of the extracted variable + coords.update(indexes.get_all_coords(k)) + elif set(self._variables[k].dims) <= needed_dims: + coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) From 75b9ae5c47ff2d22003e5e65879222a04447503e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 12 Mar 2025 17:18:45 +0100 Subject: [PATCH 03/11] fix typing issues --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d6c72b6af7e..ce89f5499c7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -134,7 +134,7 @@ def _check_coords_dims( shape: tuple[int, ...], coords: Coordinates, dims: tuple[Hashable, ...] 
): sizes = dict(zip(dims, shape, strict=True)) - extra_index_dims = set() + extra_index_dims: set[str] = set() for k, v in coords.items(): if any(d not in dims for d in v.dims): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 93675f9293f..0c1decbf3a9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1210,18 +1210,18 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: needed_dims = set(variable.dims) coords: dict[Hashable, Variable] = {} - indexes = self.xindexes + temp_indexes = self.xindexes # preserve ordering for k in self._variables: if k in self._coord_names: if ( k not in coords - and k in indexes - and set(indexes.get_all_dims(k)) & needed_dims + and k in temp_indexes + and set(temp_indexes.get_all_dims(k)) & needed_dims ): # add all coordinates of each index that shares at least one dimension # with the dimensions of the extracted variable - coords.update(indexes.get_all_coords(k)) + coords.update(temp_indexes.get_all_coords(k)) elif set(self._variables[k].dims) <= needed_dims: coords[k] = self._variables[k] From b647df78d1c49f0ac6d52d776358586cdafa9594 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 12 Mar 2025 17:19:19 +0100 Subject: [PATCH 04/11] DataArray update coords: update coords dims check --- xarray/core/coordinates.py | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 47773ddfbb6..f056280e0b8 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool: return self.to_dataset().identical(other.to_dataset()) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: # redirect to DatasetCoordinates._update_coords self._data.coords._update_coords(coords, indexes) @@ -780,7 +780,7 @@ def 
to_dataset(self) -> Dataset: return self._data._copy_listed(names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: variables = self._data._variables.copy() variables.update(coords) @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset: return self._data.dataset._copy_listed(self._names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: from xarray.core.datatree import check_alignment @@ -964,22 +964,32 @@ def __getitem__(self, key: Hashable) -> T_DataArray: return self._data._getitem_coord(key) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: coords_plus_data = coords.copy() coords_plus_data[_THIS_ARRAY] = self._data.variable dims = calculate_dimensions(coords_plus_data) - if not set(dims) <= set(self.dims): - raise ValueError( - "cannot add coordinates with new dimensions to a DataArray" - ) - self._data._coords = coords - # TODO(shoyer): once ._indexes is always populated by a dict, modify - # it to update inplace instead. 
- original_indexes = dict(self._data.xindexes) - original_indexes.update(indexes) - self._data._indexes = original_indexes + if set(dims) > set(self.dims): + for k, v in coords.items(): + if any(d not in self.dims for d in v.dims): + # allow any coordinate associated with an index that shares at least + # one of dataarray's dimensions + temp_indexes = Indexes( + indexes, {k: v for k, v in coords.items() if k in indexes} + ) + if k in indexes: + index_dims = temp_indexes.get_all_dims(k) + if any(d in self.dims for d in index_dims): + continue + raise ValueError( + f"coordinate {k} has dimensions {v.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {self.dims}" + ) + + self._data._coords = coords + self._data._indexes = indexes def _drop_coords(self, coord_names): # should drop indexed coordinates only From 90ce0f97e8d22bf3ac5441c2de861219b62e36b0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Mar 2025 08:44:35 +0100 Subject: [PATCH 05/11] rename arg dims > dim --- xarray/core/dataarray.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ce89f5499c7..9e4239d11d8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -131,25 +131,25 @@ def _check_coords_dims( - shape: tuple[int, ...], coords: Coordinates, dims: tuple[Hashable, ...] + shape: tuple[int, ...], coords: Coordinates, dim: tuple[Hashable, ...] 
): - sizes = dict(zip(dims, shape, strict=True)) + sizes = dict(zip(dim, shape, strict=True)) extra_index_dims: set[str] = set() for k, v in coords.items(): - if any(d not in dims for d in v.dims): + if any(d not in dim for d in v.dims): # allow any coordinate associated with an index that shares at least # one of dataarray's dimensions indexes = coords.xindexes if k in indexes: index_dims = indexes.get_all_dims(k) - if any(d in dims for d in index_dims): - extra_index_dims.update(d for d in v.dims if d not in dims) + if any(d in dim for d in index_dims): + extra_index_dims.update(d for d in v.dims if d not in dim) continue raise ValueError( f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " - f"dimensions {dims}" + f"dimensions {dim}" ) for d, s in v.sizes.items(): From ebcfa229ce79cc04bdf036f3edbdc46b65c8c28f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Mar 2025 09:20:56 +0100 Subject: [PATCH 06/11] add tests --- xarray/tests/test_dataarray.py | 49 ++++++++++++++++++++++++++++++++++ xarray/tests/test_dataset.py | 25 +++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 75d6d919e19..a44bff66409 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -529,6 +529,30 @@ class CustomIndex(Index): ... # test coordinate variables copied assert da.coords["x"] is not coords.variables["x"] + def test_constructor_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. 
+ pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + actual = DataArray([1.0, 2.0], coords=coords, dims="x") + + # cannot use `assert_identical()` test utility function here yet + # (indexes invariant check is still based on IndexVariable, which + # doesn't work with AnyIndex coordinate variables here) + assert actual.coords.to_dataset().equals(coords.to_dataset()) + assert list(actual.coords.xindexes) == list(coords.xindexes) + assert "x_bnds" not in actual.dims + def test_equals_and_identical(self) -> None: orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") @@ -1634,6 +1658,31 @@ def test_assign_coords_no_default_index(self) -> None: assert_identical(actual.coords, coords, check_default_indexes=False) assert "y" not in actual.xindexes + def test_assign_coords_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. 
+ pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + da = DataArray([1.0, 2.0], dims="x") + actual = da.assign_coords(coords) + + # cannot use `assert_identical()` test utility function here yet + # (indexes invariant check is still based on IndexVariable, which + # doesn't work with AnyIndex coordinate variables here) + assert actual.coords.to_dataset().equals(coords.to_dataset()) + assert list(actual.coords.xindexes) == list(coords.xindexes) + assert "x_bnds" not in actual.dims + def test_coords_alignment(self) -> None: lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index bdae9daf758..a0fd7ebc6c2 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4206,6 +4206,31 @@ def test_getitem_multiple_dtype(self) -> None: dataset = Dataset({key: ("dim0", range(1)) for key in keys}) assert_identical(dataset, dataset[keys]) + def test_getitem_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. 
+ pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords) + actual = ds["foo"] + + # cannot use `assert_identical()` test utility function here yet + # (indexes invariant check is still based on IndexVariable, which + # doesn't work with AnyIndex coordinate variables here) + assert actual.coords.to_dataset().equals(coords.to_dataset()) + assert list(actual.coords.xindexes) == list(coords.xindexes) + assert "x_bnds" not in actual.dims + def test_virtual_variables_default_coords(self) -> None: dataset = Dataset({"foo": ("x", range(10))}) expected1 = DataArray(range(10), dims="x", name="x") From f27f2d010eafdfb99aa4738742462ad3cc90a077 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Mar 2025 09:30:44 +0100 Subject: [PATCH 07/11] update whats new --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 994fc70339c..615be8e019f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,9 @@ New Features By `Benoit Bovy `_. - Support reading to `GPU memory with Zarr `_ (:pull:`10078`). By `Deepak Cherian `_. +- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray`, enabling + support for CF boundaries coordinate (e.g., ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10116`). + By `Benoit Bovy `_. 
Breaking changes ~~~~~~~~~~~~~~~~ From 528f3565d653b31f3bfa36160b3d204a616db58b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Mar 2025 14:24:41 +0100 Subject: [PATCH 08/11] cache Indexes.get_all_dims --- xarray/core/indexes.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index c2bc8b94f3f..09d685f296f 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1568,6 +1568,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): """ _index_type: type[Index] | type[pd.Index] + _index_dims: dict[Hashable, Mapping[Hashable, int]] _indexes: dict[Any, T_PandasOrXarrayIndex] _variables: dict[Any, Variable] @@ -1576,6 +1577,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): "__id_coord_names", "__id_index", "_dims", + "_index_dims", "_index_type", "_indexes", "_variables", @@ -1619,6 +1621,7 @@ def __init__( ) self._index_type = index_type + self._index_dims = {} self._indexes = dict(**indexes) self._variables = dict(**variables) @@ -1737,7 +1740,13 @@ def get_all_dims( """ from xarray.core.variable import calculate_dimensions - return calculate_dimensions(self.get_all_coords(key, errors=errors)) + if key in self._index_dims: + return self._index_dims[key] + else: + dims = calculate_dimensions(self.get_all_coords(key, errors=errors)) + if dims: + self._index_dims[key] = dims + return dims def group_by_index( self, From 695fb8611c16d64a4f1b199e1324ecfc53592c5e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 14 Mar 2025 11:20:20 +0100 Subject: [PATCH 09/11] rework checking dataarray coordinates --- xarray/core/coordinates.py | 29 ++++++----------- xarray/core/dataarray.py | 67 ++++++++++++++++++++++++++------------ 2 files changed, 56 insertions(+), 40 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index f056280e0b8..4eb089314a7 100644 --- a/xarray/core/coordinates.py +++ 
b/xarray/core/coordinates.py @@ -966,27 +966,18 @@ def __getitem__(self, key: Hashable) -> T_DataArray: def _update_coords( self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: + from xarray.core.dataarray import check_dataarray_coords + coords_plus_data = coords.copy() coords_plus_data[_THIS_ARRAY] = self._data.variable - dims = calculate_dimensions(coords_plus_data) - - if set(dims) > set(self.dims): - for k, v in coords.items(): - if any(d not in self.dims for d in v.dims): - # allow any coordinate associated with an index that shares at least - # one of dataarray's dimensions - temp_indexes = Indexes( - indexes, {k: v for k, v in coords.items() if k in indexes} - ) - if k in indexes: - index_dims = temp_indexes.get_all_dims(k) - if any(d in self.dims for d in index_dims): - continue - raise ValueError( - f"coordinate {k} has dimensions {v.dims}, but these " - "are not a subset of the DataArray " - f"dimensions {self.dims}" - ) + coords_dims = set(calculate_dimensions(coords_plus_data)) + obj_dims = set(self.dims) + + if coords_dims > obj_dims: + # need more checks + check_dataarray_coords( + self._data.shape, Coordinates(coords, indexes), self.dims + ) self._data._coords = coords self._data._indexes = indexes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 9e4239d11d8..578ce63e789 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -130,34 +130,59 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset) -def _check_coords_dims( - shape: tuple[int, ...], coords: Coordinates, dim: tuple[Hashable, ...] +def check_dataarray_coords( + shape: tuple[int, ...], coords: Coordinates, dims: tuple[Hashable, ...] 
): - sizes = dict(zip(dim, shape, strict=True)) - extra_index_dims: set[str] = set() - - for k, v in coords.items(): - if any(d not in dim for d in v.dims): - # allow any coordinate associated with an index that shares at least - # one of dataarray's dimensions - indexes = coords.xindexes - if k in indexes: - index_dims = indexes.get_all_dims(k) - if any(d in dim for d in index_dims): - extra_index_dims.update(d for d in v.dims if d not in dim) - continue + """Check that ``coords`` dimension names and sizes do not conflict with + array ``shape`` and dimensions ``dims``. + + If a coordinate is associated with an index, the coordinate may have any + arbitrary dimension(s) as long as the index's dimensions (i.e., the union of + the dimensions of all coordinates associated with this index) intersect the + array dimensions. + + If a coordinate has no index, then its dimensions must match (or be a subset + of) the array dimensions. Scalar coordinates are also allowed. + + """ + indexes = coords.xindexes + skip_check_coord_name: set[Hashable] = set() + skip_check_dim_size: set[Hashable] = set() + + # check dimension names + for name, var in coords.items(): + if name in skip_check_coord_name: + continue + elif name in indexes: + index_dims = indexes.get_all_dims(name) + if any(d in dims for d in index_dims): + # can safely skip checking index's non-array dimensions + # and index's other coordinates since those must be all + # included in the dataarray so the index is not corrupted + skip_check_coord_name.update(indexes.get_all_coords(name)) + skip_check_dim_size.update(d for d in index_dims if d not in dims) + raise_error = False + else: + raise_error = True + else: + raise_error = any(d not in dims for d in var.dims) + if raise_error: raise ValueError( - f"coordinate {k} has dimensions {v.dims}, but these " + f"coordinate {name} has dimensions {var.dims}, but these " "are not a subset of the DataArray " - f"dimensions {dim}" + f"dimensions {dims}" ) - for d, s in 
v.sizes.items(): - if d not in extra_index_dims and s != sizes[d]: + # check dimension sizes + sizes = dict(zip(dims, shape, strict=True)) + + for name, var in coords.items(): + for d, s in var.sizes.items(): + if d not in skip_check_dim_size and s != sizes[d]: raise ValueError( f"conflicting sizes for dimension {d!r}: " f"length {sizes[d]} on the data but length {s} on " - f"coordinate {k!r}" + f"coordinate {name!r}" ) @@ -497,7 +522,7 @@ def __init__( if not isinstance(coords, Coordinates): coords = create_coords_with_default_indexes(coords) - _check_coords_dims(data.shape, coords, dims) + check_dataarray_coords(data.shape, coords, dims) indexes = dict(coords.xindexes) coords = {k: v.copy() for k, v in coords.variables.items()} From bc578d846bc9a6f38644c49091ea39320144920a Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 14 Mar 2025 11:48:46 +0100 Subject: [PATCH 10/11] assert invariants: skip check IndexVariable ... ... when check_default_indexes=False. --- xarray/testing/assertions.py | 22 +++++++++++++++------- xarray/tests/test_dataarray.py | 12 ++---------- xarray/tests/test_dataset.py | 6 +----- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 8a2dba9261f..ec7b4fdd410 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -330,10 +330,13 @@ def _assert_indexes_invariants_checks( k: type(v) for k, v in indexes.items() } - index_vars = { - k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable) - } - assert indexes.keys() <= index_vars, (set(indexes), index_vars) + if check_default: + index_vars = { + k + for k, v in possible_coord_variables.items() + if isinstance(v, IndexVariable) + } + assert indexes.keys() <= index_vars, (set(indexes), index_vars) # check pandas index wrappers vs. 
coordinate data adapters for k, index in indexes.items(): @@ -399,9 +402,14 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): da.dims, {k: v.dims for k, v in da._coords.items()}, ) - assert all( - isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,) - ), {k: type(v) for k, v in da._coords.items()} + + if check_default_indexes: + assert all( + isinstance(v, IndexVariable) + for (k, v) in da._coords.items() + if v.dims == (k,) + ), {k: type(v) for k, v in da._coords.items()} + for k, v in da._coords.items(): _assert_variable_invariants(v, k) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a44bff66409..1883b9eb407 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -546,11 +546,7 @@ class AnyIndex(Index): actual = DataArray([1.0, 2.0], coords=coords, dims="x") - # cannot use `assert_identical()` test utility function here yet - # (indexes invariant check is still based on IndexVariable, which - # doesn't work with AnyIndex coordinate variables here) - assert actual.coords.to_dataset().equals(coords.to_dataset()) - assert list(actual.coords.xindexes) == list(coords.xindexes) + assert_identical(actual.coords, coords, check_default_indexes=False) assert "x_bnds" not in actual.dims def test_equals_and_identical(self) -> None: @@ -1676,11 +1672,7 @@ class AnyIndex(Index): da = DataArray([1.0, 2.0], dims="x") actual = da.assign_coords(coords) - # cannot use `assert_identical()` test utility function here yet - # (indexes invariant check is still based on IndexVariable, which - # doesn't work with AnyIndex coordinate variables here) - assert actual.coords.to_dataset().equals(coords.to_dataset()) - assert list(actual.coords.xindexes) == list(coords.xindexes) + assert_identical(actual.coords, coords, check_default_indexes=False) assert "x_bnds" not in actual.dims def test_coords_alignment(self) -> None: diff --git a/xarray/tests/test_dataset.py 
b/xarray/tests/test_dataset.py index a0fd7ebc6c2..b47599c7cd0 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4224,11 +4224,7 @@ class AnyIndex(Index): ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords) actual = ds["foo"] - # cannot use `assert_identical()` test utility function here yet - # (indexes invariant check is still based on IndexVariable, which - # doesn't work with AnyIndex coordinate variables here) - assert actual.coords.to_dataset().equals(coords.to_dataset()) - assert list(actual.coords.xindexes) == list(coords.xindexes) + assert_identical(actual.coords, coords, check_default_indexes=False) assert "x_bnds" not in actual.dims def test_virtual_variables_default_coords(self) -> None: From 5e8093b070f2e7738193de86ee22abfd6f6c0e65 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 14 Mar 2025 14:45:34 +0100 Subject: [PATCH 11/11] formatting: improve perf of coord/data vars sections No need to have a mapping of DataArrays to format those coordinate and data variables sections. 
--- xarray/core/dataset_variables.py | 2 +- xarray/core/formatting.py | 6 ++++-- xarray/core/formatting_html.py | 11 ++++++----- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataset_variables.py b/xarray/core/dataset_variables.py index 6521da61444..8f5b7442c7a 100644 --- a/xarray/core/dataset_variables.py +++ b/xarray/core/dataset_variables.py @@ -40,7 +40,7 @@ def __getitem__(self, key: Hashable) -> "DataArray": raise KeyError(key) def __repr__(self) -> str: - return formatting.data_vars_repr(self) + return formatting.data_vars_repr(self.variables) @property def variables(self) -> Mapping[Hashable, Variable]: diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 993cddf2b57..6302ee00210 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -429,7 +429,7 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None): if col_width is None: col_width = _calculate_col_width(coords) return _mapping_repr( - coords, + coords.variables, title="Coordinates", summarizer=summarize_variable, expand_option_name="display_expand_coords", @@ -743,7 +743,9 @@ def dataset_repr(ds): if unindexed_dims_str: summary.append(unindexed_dims_str) - summary.append(data_vars_repr(ds.data_vars, col_width=col_width, max_rows=max_rows)) + summary.append( + data_vars_repr(ds.data_vars.variables, col_width=col_width, max_rows=max_rows) + ) display_default_indexes = _get_boolean_with_default( "display_default_indexes", False diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index eb9073cd869..69c128ae5fd 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -113,10 +113,11 @@ def summarize_variable(name, var, is_index=False, dtype=None) -> str: ) -def summarize_coords(variables) -> str: +def summarize_coords(coords) -> str: li_items = [] - for k, v in variables.items(): - li_content = summarize_variable(k, v, is_index=k in variables.xindexes) + indexes = 
coords.xindexes + for k, v in coords.variables.items(): + li_content = summarize_variable(k, v, is_index=k in indexes) li_items.append(f"
<li class='xr-var-item'>{li_content}</li>") vars_li = "".join(li_items) @@ -339,7 +340,7 @@ def dataset_repr(ds) -> str: sections = [ dim_section(ds), coord_section(ds.coords), - datavar_section(ds.data_vars), + datavar_section(ds.data_vars.variables), index_section(_get_indexes_dict(ds.xindexes)), attr_section(ds.attrs), ] @@ -415,7 +416,7 @@ def datatree_node_repr(group_title: str, node: DataTree, show_inherited=False) - sections.append(inherited_coord_section(inherited_coords)) sections += [ - datavar_section(ds.data_vars), + datavar_section(ds.data_vars.variables), attr_section(ds.attrs), ]