From fc09cc077ab450da7623bf2dc4a3c426d7c0b420 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 9 Jul 2024 12:38:16 +0200 Subject: [PATCH 1/5] allow filtering by ids --- src/spatialdata/_core/spatialdata.py | 55 +++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 41de30e0..af2a5b64 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -3,7 +3,7 @@ import hashlib import os import warnings -from collections.abc import Generator +from collections.abc import Generator, Iterable from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -2143,6 +2143,59 @@ def __delitem__(self, key: str) -> None: element_type, _, _ = self._find_element(key) getattr(self, element_type).__delitem__(key) + def filter_elements_by_instances( + self, + element_names: Iterable[str], + instances: Iterable[int | str], + region_names: Iterable[str] | str | None = None, + ) -> dict[str, DaskDataFrame | GeoDataFrame | AnnData]: + """ + Filter elements to contain only certain instances. + + This filters both SpatialElements (points and shapes) + as well as tables to only contain certain IDs. In case of tables + the instance key column of table.obs will be filtered on and not + table.obs.index. Filtering labels by ID is currently not supported + as this is an expensive operation. Should you require this + please open an issue on github.com/scverse/spatialdata. Lastly, + tables not annotating an element cannot be filtered. + + element_names: + Name of either points, shapes or table elements within the Spatialdata + object. + instances: + The instance IDs to filter the elements on. + region_names: + If filtering instances in a table, indicate the region_names (the names of the SpatialElement) for + which you want to filter the instances of the table. If not specified, the table instances for all regions + annotated by the table will be filtered by the given instances. + """ + element_dict = {} + element_names = [element_names] if isinstance(element_names, str) else list(element_names) + for element_name in element_names: + element = self.get(element_name) + if element is not None: + if (model := get_model(element)) == PointsModel: + instance_key = element.attrs[PointsModel.ATTRS_KEY][PointsModel.INSTANCE_KEY] + element_dict[element_name] = element[element[instance_key].isin(instances)] + elif model == ShapesModel: + element_dict[element_name] = element[element.index.isin(instances)] + elif model == TableModel: + instance_key = element.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] + region_key = element.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] + if region_names: + region_names = [region_names] if isinstance(region_names, str) else region_names + element = element[element.obs[region_key].isin(region_names)] + regions = element.obs[region_key].cat.categories.tolist() + element_dict[element_name] = element[element.obs[instance_key].isin(instances)].copy() + element_dict[element_name].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = regions + TableModel().validate(element_dict[element_name]) + else: + raise TypeError(f"`{model}` is not a valid model for filtering of instances.") + else: + raise KeyError(f"`{element_name}` is not an element in the SpatialData object.") + return element_dict + class QueryManager: """Perform queries on SpatialData objects.""" From 122aa5f22ecbd0cc8c3193dd0ff55a4d25dec1f7 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 28 Jan 2025 14:53:17 +0100 Subject: [PATCH 2/5] wip match_table_to_sdata() --- .../_core/query/relational_query.py | 28 +++++++++++++++++++ src/spatialdata/_core/spatialdata.py | 17 ++++------- src/spatialdata/models/models.py | 5 +++- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 4154c5f2..c838f4ba 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -770,6 +770,34 @@ def match_element_to_table( return element_dict, table +def match_sdata_to_table( + sdata: SpatialData, + table: AnnData, + table_name: str, +) -> SpatialData: + """ + Filter the elements of a SpatialData object to match only the rows present in the table. + + Parameters + ---------- + sdata + SpatialData object containing all the elements and tables. + table + The table to join with the spatial elements. + table_name + The name of the table in the returned SpatialData object. + + Notes + ----- + Eventual rows in the table that do not annoate any instance are preserved. + """ + annotated_regions = SpatialData.get_annotated_regions(table) + filtered_elements, filtered_table = join_spatialelement_table( + sdata, spatial_element_names=annotated_regions, table=table, how="right" + ) + return SpatialData.init_from_elements(filtered_elements | {table_name: table}) + + @dataclass class _ValueOrigin: origin: str diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index f42b4df8..1de457b0 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2417,7 +2417,7 @@ def __delitem__(self, key: str) -> None: element_type, _, _ = self._find_element(key) getattr(self, element_type).__delitem__(key) - def filter_elements_by_instances( + def filter( self, element_names: Iterable[str], instances: Iterable[int | str], @@ -2426,12 +2426,10 @@ def filter_elements_by_instances( """ Filter elements to contain only certain instances. - This filters both SpatialElements (points and shapes) - as well as tables to only contain certain IDs. In case of tables - the instance key column of table.obs will be filtered on and not - table.obs.index. Filtering labels by ID is currently not supported - as this is an expensive operation. Should you require this - please open an issue on github.com/scverse/spatialdata. Lastly, + This filters both SpatialElements (points and shapes) as well as tables to only contain certain IDs. In case of + tables the instance key column of table.obs will be filtered on and not table.obs.index; in case of Points and + Shapes, the index will be used. Filtering labels by ID is currently not supported as this is an expensive + operation. Should you require this please open an issue on github.com/scverse/spatialdata. Lastly, tables not annotating an element cannot be filtered. element_names: @@ -2449,10 +2447,7 @@ def filter_elements_by_instances( for element_name in element_names: element = self.get(element_name) if element is not None: - if (model := get_model(element)) == PointsModel: - instance_key = element.attrs[PointsModel.ATTRS_KEY][PointsModel.INSTANCE_KEY] - element_dict[element_name] = element[element[instance_key].isin(instances)] - elif model == ShapesModel: + if (model := get_model(element)) in [PointsModel, ShapesModel]: element_dict[element_name] = element[element.index.isin(instances)] elif model == TableModel: instance_key = element.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index b2663e56..1335a483 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -987,6 +987,7 @@ def parse( region: str | list[str] | None = None, region_key: str | None = None, instance_key: str | None = None, + overwrite_metadata: bool = False, ) -> AnnData: """ Parse the :class:`anndata.AnnData` to be compatible with the model. @@ -1001,6 +1002,8 @@ def parse( Key in `adata.obs` that specifies the region. instance_key Key in `adata.obs` that specifies the instance. + overwrite_metadata + If `True`, the `region`, `region_key` and `instance_key` metadata will be overwritten. Returns ------- @@ -1011,7 +1014,7 @@ def parse( n_args = sum([region is not None, region_key is not None, instance_key is not None]) if n_args == 0: return adata - if n_args > 0: + if n_args > 0 and not overwrite_metadata: if cls.ATTRS_KEY in adata.uns: raise ValueError( f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as" From 829a4cddad654dabac19d2816a4f7d643c1e6ea8 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sun, 2 Feb 2025 17:30:18 +0100 Subject: [PATCH 3/5] add match_table_to_sdata(); improve table model, get_annotated_regions() --- docs/api/operations.md | 1 + docs/tutorials/notebooks | 2 +- src/spatialdata/__init__.py | 2 + .../_core/query/relational_query.py | 23 +++- src/spatialdata/_core/spatialdata.py | 7 +- src/spatialdata/models/models.py | 50 +++++-- tests/core/query/test_relational_query.py | 8 ++ ...t_relational_query_match_sdata_to_table.py | 122 ++++++++++++++++++ tests/models/test_models.py | 43 +++++- 9 files changed, 232 insertions(+), 26 deletions(-) create mode 100644 tests/core/query/test_relational_query_match_sdata_to_table.py diff --git a/docs/api/operations.md b/docs/api/operations.md index f82331c7..937b8dbc 100644 --- a/docs/api/operations.md +++ b/docs/api/operations.md @@ -14,6 +14,7 @@ Operations on `SpatialData` objects. .. autofunction:: join_spatialelement_table .. autofunction:: match_element_to_table .. autofunction:: match_table_to_element +.. autofunction:: match_sdata_to_table .. autofunction:: concatenate .. autofunction:: transform .. autofunction:: rasterize diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 02859d31..2fa87d5a 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 02859d31a0df0245d36af905d3eb3068a9965445 +Subproject commit 2fa87d5a629252dd8b85430ed9d2f425a8b062ed diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 59866059..16f659d6 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -40,6 +40,7 @@ "join_spatialelement_table", "match_element_to_table", "match_table_to_element", + "match_sdata_to_table", "SpatialData", "get_extent", "get_centroids", @@ -72,6 +73,7 @@ get_values, join_spatialelement_table, match_element_to_table, + match_sdata_to_table, match_table_to_element, ) from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index c838f4ba..a2678733 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -774,6 +774,7 @@ def match_sdata_to_table( sdata: SpatialData, table: AnnData, table_name: str, + how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", ) -> SpatialData: """ Filter the elements of a SpatialData object to match only the rows present in the table. @@ -783,19 +784,27 @@ def match_sdata_to_table( sdata SpatialData object containing all the elements and tables. table - The table to join with the spatial elements. + The table to join with the spatial elements. Has precedence over `table_name`. table_name - The name of the table in the returned SpatialData object. + The name of the table to join with the SpatialData object if `table` is not provided. Also, `table_name` is used + to name the table in the returned `SpatialData` object. + how + The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". - Notes - ----- - Eventual rows in the table that do not annoate any instance are preserved. """ + _, region_key, instance_key = get_table_keys(table) annotated_regions = SpatialData.get_annotated_regions(table) filtered_elements, filtered_table = join_spatialelement_table( - sdata, spatial_element_names=annotated_regions, table=table, how="right" + sdata, spatial_element_names=annotated_regions, table=table, how=how + ) + filtered_table = TableModel.parse( + filtered_table, + region=annotated_regions, + region_key=region_key, + instance_key=instance_key, + overwrite_metadata=True, ) - return SpatialData.init_from_elements(filtered_elements | {table_name: table}) + return SpatialData.init_from_elements(filtered_elements | {table_name: filtered_table}) @dataclass diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 1de457b0..64bc70d2 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -263,7 +263,7 @@ def from_elements_dict( return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs) @staticmethod - def get_annotated_regions(table: AnnData) -> str | list[str]: + def get_annotated_regions(table: AnnData) -> list[str]: """ Get the regions annotated by a table. @@ -276,8 +276,9 @@ def get_annotated_regions(table: AnnData) -> str | list[str]: ------- The annotated regions. """ - regions, _, _ = get_table_keys(table) - return regions + from spatialdata.models.models import _get_region_metadata_from_region_key_column + + return _get_region_metadata_from_region_key_column(table) @staticmethod def get_region_key_column(table: AnnData) -> pd.Series: diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 1335a483..d3622a3e 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -1013,31 +1013,38 @@ def parse( # either all live in adata.uns or all be passed in as argument n_args = sum([region is not None, region_key is not None, instance_key is not None]) if n_args == 0: - return adata - if n_args > 0 and not overwrite_metadata: - if cls.ATTRS_KEY in adata.uns: - raise ValueError( - f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as" - f"as argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set." - ) - elif cls.ATTRS_KEY in adata.uns: + if cls.ATTRS_KEY not in adata.uns: + # table not annotating any element + return adata attr = adata.uns[cls.ATTRS_KEY] region = attr[cls.REGION_KEY] region_key = attr[cls.REGION_KEY_KEY] instance_key = attr[cls.INSTANCE_KEY] + elif n_args > 0 and not overwrite_metadata and cls.ATTRS_KEY in adata.uns: + raise ValueError( + f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as" + f" argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set." + ) + + if cls.ATTRS_KEY not in adata.uns: + adata.uns[cls.ATTRS_KEY] = {} + if region is None: + raise ValueError(f"`{cls.REGION_KEY}` must be provided.") if region_key is None: raise ValueError(f"`{cls.REGION_KEY_KEY}` must be provided.") + if instance_key is None: + raise ValueError("`instance_key` must be provided.") + if isinstance(region, np.ndarray): region = region.tolist() - if region is None: - raise ValueError(f"`{cls.REGION_KEY}` must be provided.") region_: list[str] = region if isinstance(region, list) else [region] if not adata.obs[region_key].isin(region_).all(): raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.") - if instance_key is None: - raise ValueError("`instance_key` must be provided.") + adata.uns[cls.ATTRS_KEY][cls.REGION_KEY] = region + adata.uns[cls.ATTRS_KEY][cls.REGION_KEY_KEY] = region_key + adata.uns[cls.ATTRS_KEY][cls.INSTANCE_KEY] = instance_key # note! this is an expensive check and therefore we skip it during validation # https://github.com/scverse/spatialdata/issues/715 @@ -1050,8 +1057,6 @@ def parse( f"Instance key column for region(s) `{', '.join(not_unique)}` does not contain only unique values" ) - attr = {"region": region, "region_key": region_key, "instance_key": instance_key} - adata.uns[cls.ATTRS_KEY] = attr cls().validate(adata) return convert_region_column_to_categorical(adata) @@ -1132,3 +1137,20 @@ def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]: raise ValueError( "No spatialdata_attrs key found in table.uns, therefore, no table keys found. Please parse the table." ) + + +def _get_region_metadata_from_region_key_column(table: AnnData) -> list[str]: + _, region_key, instance_key = get_table_keys(table) + region_key_column = table.obs[region_key] + if not isinstance(region_key_column.dtype, CategoricalDtype): + warnings.warn( + f"The region key column `{region_key}` is not of type `pd.Categorical`. Consider casting it to " + f"improve performance.", + UserWarning, + stacklevel=2, + ) + annotated_regions = region_key_column.unique().tolist() + else: + annotated_regions = table.obs[region_key].cat.remove_unused_categories().cat.categories.unique().tolist() + assert isinstance(annotated_regions, list) + return annotated_regions diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 877349d4..3038fdb4 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -901,3 +901,11 @@ def test_get_element_annotators(full_sdata): full_sdata.tables["another_table"] = another_table names = get_element_annotators(full_sdata, "labels2d") assert names == {"another_table", "table"} + + +# def test_match_table_to_element(sdata_query_aggregation): +# matched_table = match_table_to_element(sdata=sdata_query_aggregation, element_name="values_circles") +# arr = np.array(list(reversed(sdata_query_aggregation["values_circles"].index))) +# sdata_query_aggregation["values_circles"].index = arr +# matched_table_reversed = match_table_to_element(sdata=sdata_query_aggregation, element_name="values_circles") +# assert matched_table.obs.index.tolist() == list(reversed(matched_table_reversed.obs.index.tolist())) diff --git a/tests/core/query/test_relational_query_match_sdata_to_table.py b/tests/core/query/test_relational_query_match_sdata_to_table.py new file mode 100644 index 00000000..97c7781e --- /dev/null +++ b/tests/core/query/test_relational_query_match_sdata_to_table.py @@ -0,0 +1,122 @@ +import pytest + +from spatialdata import concatenate, match_sdata_to_table +from spatialdata.datasets import blobs_annotating_element + +# constructing the example data; let's use a global variable as we can reuse the same object for all the tests +# without having to recreate it +sdata1 = blobs_annotating_element("blobs_polygons") +sdata2 = blobs_annotating_element("blobs_polygons") +sdata = concatenate({"sdata1": sdata1, "sdata2": sdata2}, concatenate_tables=True) +sdata["table"].obs["value"] = list(range(sdata["table"].obs.shape[0])) + + +def test_match_sdata_to_table_filter_specific_instances(): + """ + Filter to keep only specific instances. Note that it works even when the table annotates multiple elements. + """ + matched = match_sdata_to_table( + sdata, + table=sdata["table"][sdata["table"].obs.instance_id.isin([1, 2])], + table_name="table", + ) + assert len(matched["table"]) == 4 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" in matched + + +def test_match_sdata_to_table_filter_specific_instances_element(): + """ + Filter to keep only specific instances, in a specific element. + """ + matched = match_sdata_to_table( + sdata, + table=sdata["table"][ + sdata["table"].obs.instance_id.isin([1, 2]) & (sdata["table"].obs.region == "blobs_polygons-sdata1") + ], + table_name="table", + ) + assert len(matched["table"]) == 2 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" not in matched + + +def test_match_sdata_to_table_filter_by_threshold(): + """ + Filter by a threshold on a value column, in a specific element. + """ + matched = match_sdata_to_table( + sdata, + table=sdata["table"][sdata["table"].obs.query('value < 5 and region == "blobs_polygons-sdata1"').index], + table_name="table", + ) + assert len(matched["table"]) == 5 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" not in matched + + +def test_match_sdata_to_table_subset_certain_obs(): + """ + Subset to certain obs (we could also subset to certain var or layer). + """ + matched = match_sdata_to_table( + sdata, + table=sdata["table"][[0, 1, 2, 3]], + table_name="table", + ) + assert len(matched["table"]) == 4 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" not in matched + + +def test_match_sdata_to_table_shapes_and_points(): + """ + The function works both for shapes (examples above) and points. + Changes the target of the table to labels. + """ + sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "points")) + sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category") + sdata.set_table_annotates_spatialelement( + table_name="table", + region=["blobs_points-sdata1", "blobs_points-sdata2"], + region_key="region", + instance_key="instance_id", + ) + + matched = match_sdata_to_table( + sdata, + table=sdata["table"], + table_name="table", + ) + + assert len(matched["table"]) == 10 + assert "blobs_points-sdata1" in matched + assert "blobs_points-sdata2" in matched + assert "blobs_polygons-sdata1" not in matched + + +def test_match_sdata_to_table_match_labels_error(): + """ + match_sdata_to_table() uses the join operations; so when trying to match labels, the error will be raised by the + join. + """ + sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("points", "labels")) + sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category") + sdata.set_table_annotates_spatialelement( + table_name="table", + region=["blobs_labels-sdata1", "blobs_labels-sdata2"], + region_key="region", + instance_key="instance_id", + ) + + with pytest.warns(UserWarning, match="Element type `labels` not supported for 'right' join. Skipping "): + matched = match_sdata_to_table( + sdata, + table=sdata["table"], + table_name="table", + ) + + assert len(matched["table"]) == 10 + assert "blobs_labels-sdata1" in matched + assert "blobs_labels-sdata2" in matched + assert "blobs_points-sdata1" not in matched diff --git a/tests/models/test_models.py b/tests/models/test_models.py index f5889f96..441d5831 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -27,6 +27,7 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._core.validation import ValidationError from spatialdata._types import ArrayLike +from spatialdata.models import get_table_keys from spatialdata.models._utils import ( force_2d, points_dask_dataframe_to_geopandas, @@ -367,6 +368,46 @@ def test_table_model( assert TableModel.REGION_KEY_KEY in table.uns[TableModel.ATTRS_KEY] assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == region + # error when trying to parse a table by specifying region, region_key, instance_key, but these keys are + # already set + with pytest.raises(ValueError, match=" has already been set"): + _ = TableModel.parse(adata, region=region, region_key=region_key, instance_key="A") + + # error when region is missing + with pytest.raises(ValueError, match="`region` must be provided"): + _ = TableModel.parse(adata, region_key=region_key, instance_key="A", overwrite_metadata=True) + + # error when region_key is missing + with pytest.raises(ValueError, match="`region_key` must be provided"): + _ = TableModel.parse(adata, region=region, instance_key="A", overwrite_metadata=True) + + # error when instance_key is missing + with pytest.raises(ValueError, match="`instance_key` must be provided"): + _ = TableModel.parse(adata, region=region, region_key=region_key, overwrite_metadata=True) + + # we try to overwrite, but the values in the `region_key` column do not match the expected `region` values + with pytest.raises(ValueError, match="values do not match with `region` values"): + _ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True) + + # we correctly overwrite; here we check that the metadata is updated + region_, region_key_, instance_key_ = get_table_keys(table) + assert region_ == region + assert region_key_ == region_key + assert instance_key_ == "A" + + # let's fix the region_key column + table.obs["B"] = ["element"] * len(table) + _ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True) + + region_, region_key_, instance_key_ = get_table_keys(table) + assert region_ == "element" + assert region_key_ == "B" + assert instance_key_ == "C" + + # we can parse a table when no metadata is present (i.e. the table does not annotate any element) + del table.uns[TableModel.ATTRS_KEY] + _ = TableModel.parse(table) + @pytest.mark.parametrize( "name", [ @@ -410,7 +451,7 @@ def test_table_instance_key_values_not_unique(self, model: TableModel, region: s adata.obs["A"] = [1] * 10 with pytest.raises(ValueError, match=re.escape("Instance key column for region(s) `sample_1, sample_2`")): - model.parse(adata, region=region, region_key=region_key, instance_key="A") + model.parse(adata, region=region, region_key=region_key, instance_key="A", overwrite_metadata=True) @pytest.mark.parametrize( "key", From 0d963a09026ff43c4a7dfafd2a30657145a02495 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sun, 2 Feb 2025 17:51:32 +0100 Subject: [PATCH 4/5] remove filter() --- src/spatialdata/_core/spatialdata.py | 50 +--------------------------- 1 file changed, 1 insertion(+), 49 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 64bc70d2..6dc4537f 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -4,7 +4,7 @@ import json import os import warnings -from collections.abc import Generator, Iterable, Mapping +from collections.abc import Generator, Mapping from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -2418,54 +2418,6 @@ def __delitem__(self, key: str) -> None: element_type, _, _ = self._find_element(key) getattr(self, element_type).__delitem__(key) - def filter( - self, - element_names: Iterable[str], - instances: Iterable[int | str], - region_names: Iterable[str] | str | None = None, - ) -> dict[str, DaskDataFrame | GeoDataFrame | AnnData]: - """ - Filter elements to contain only certain instances. - - This filters both SpatialElements (points and shapes) as well as tables to only contain certain IDs. In case of - tables the instance key column of table.obs will be filtered on and not table.obs.index; in case of Points and - Shapes, the index will be used. Filtering labels by ID is currently not supported as this is an expensive - operation. Should you require this please open an issue on github.com/scverse/spatialdata. Lastly, - tables not annotating an element cannot be filtered. - - element_names: - Name of either points, shapes or table elements within the Spatialdata - object. - instances: - The instance IDs to filter the elements on. - region_names: - If filtering instances in a table, indicate the region_names (the names of the SpatialElement) for - which you want to filter the instances of the table. If not specified, the table instances for all regions - annotated by the table will be filtered by the given instances. - """ - element_dict = {} - element_names = [element_names] if isinstance(element_names, str) else list(element_names) - for element_name in element_names: - element = self.get(element_name) - if element is not None: - if (model := get_model(element)) in [PointsModel, ShapesModel]: - element_dict[element_name] = element[element.index.isin(instances)] - elif model == TableModel: - instance_key = element.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] - region_key = element.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] - if region_names: - region_names = [region_names] if isinstance(region_names, str) else region_names - element = element[element.obs[region_key].isin(region_names)] - regions = element.obs[region_key].cat.categories.tolist() - element_dict[element_name] = element[element.obs[instance_key].isin(instances)].copy() - element_dict[element_name].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = regions - TableModel().validate(element_dict[element_name]) - else: - raise TypeError(f"`{model}` is not a valid model for filtering of instances.") - else: - raise KeyError(f"`{element_name}` is not an element in the SpatialData object.") - return element_dict - @property def attrs(self) -> dict[Any, Any]: """ From d79e5335c60d7f55d1e991a0f3be4c7893fe3be9 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sun, 2 Feb 2025 18:15:58 +0100 Subject: [PATCH 5/5] add default argument for match_sdata_to_table(); fix tests --- .../_core/query/relational_query.py | 4 +- ...t_relational_query_match_sdata_to_table.py | 38 +++++++++++++++---- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index e0255365..c21e9e3a 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -775,8 +775,8 @@ def match_element_to_table( def match_sdata_to_table( sdata: SpatialData, - table: AnnData, table_name: str, + table: AnnData | None = None, how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", ) -> SpatialData: """ @@ -795,6 +795,8 @@ def match_sdata_to_table( The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". """ + if table is None: + table = sdata[table_name] _, region_key, instance_key = get_table_keys(table) annotated_regions = SpatialData.get_annotated_regions(table) filtered_elements, filtered_table = join_spatialelement_table( diff --git a/tests/core/query/test_relational_query_match_sdata_to_table.py b/tests/core/query/test_relational_query_match_sdata_to_table.py index 97c7781e..6d4fcadf 100644 --- a/tests/core/query/test_relational_query_match_sdata_to_table.py +++ b/tests/core/query/test_relational_query_match_sdata_to_table.py @@ -1,14 +1,20 @@ import pytest -from spatialdata import concatenate, match_sdata_to_table +from spatialdata import SpatialData, concatenate, match_sdata_to_table from spatialdata.datasets import blobs_annotating_element -# constructing the example data; let's use a global variable as we can reuse the same object for all the tests + +def _make_test_data() -> SpatialData: + sdata1 = blobs_annotating_element("blobs_polygons") + sdata2 = blobs_annotating_element("blobs_polygons") + sdata = concatenate({"sdata1": sdata1, "sdata2": sdata2}, concatenate_tables=True) + sdata["table"].obs["value"] = list(range(sdata["table"].obs.shape[0])) + return sdata + + +# constructing the example data; let's use a global variable as we can reuse the same object on most tests # without having to recreate it -sdata1 = blobs_annotating_element("blobs_polygons") -sdata2 = blobs_annotating_element("blobs_polygons") -sdata = concatenate({"sdata1": sdata1, "sdata2": sdata2}, concatenate_tables=True) -sdata["table"].obs["value"] = list(range(sdata["table"].obs.shape[0])) +sdata = _make_test_data() def test_match_sdata_to_table_filter_specific_instances(): @@ -74,6 +80,7 @@ def test_match_sdata_to_table_shapes_and_points(): The function works both for shapes (examples above) and points. Changes the target of the table to labels. """ + sdata = _make_test_data() sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "points")) sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category") sdata.set_table_annotates_spatialelement( @@ -100,7 +107,8 @@ def test_match_sdata_to_table_match_labels_error(): match_sdata_to_table() uses the join operations; so when trying to match labels, the error will be raised by the join. """ - sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("points", "labels")) + sdata = _make_test_data() + sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "labels")) sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category") sdata.set_table_annotates_spatialelement( table_name="table", @@ -109,7 +117,10 @@ def test_match_sdata_to_table_match_labels_error(): instance_key="instance_id", ) - with pytest.warns(UserWarning, match="Element type `labels` not supported for 'right' join. Skipping "): + with pytest.warns( + UserWarning, + match="Element type `labels` not supported for 'right' join. Skipping ", + ): matched = match_sdata_to_table( sdata, table=sdata["table"], @@ -120,3 +131,14 @@ def test_match_sdata_to_table_match_labels_error(): assert "blobs_labels-sdata1" in matched assert "blobs_labels-sdata2" in matched assert "blobs_points-sdata1" not in matched + + +def test_match_sdata_to_table_no_table_argument(): + """ + If no table argument is passed, the table_name argument will be used to match the table. + """ + matched = match_sdata_to_table(sdata=sdata, table_name="table") + + assert len(matched["table"]) == 10 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" in matched