
Reduce reader memory consumption #228

Draft · wants to merge 23 commits into base: main

Commits (23)
7b7bc7e
Save reader elements incrementally
marcovarrone Oct 28, 2024
d0fed4d
Consolidated metadata
marcovarrone Oct 28, 2024
d13f581
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2024
43f0d04
Write consolidated data only when output_path is set
marcovarrone Oct 28, 2024
75df7f2
Merge branch 'scverse:main' into low_memory
marcovarrone Mar 15, 2025
1d4c734
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2025
bbfa4f7
Use public write API
marcovarrone Mar 17, 2025
209a0c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2025
0d62dd9
Save cell_circles
marcovarrone Mar 17, 2025
dccccd1
Merge branch 'low_memory' of github.com:marcovarrone/spatialdata-io i…
marcovarrone Mar 17, 2025
612e2f7
Save region only if output_path is defined
marcovarrone Mar 17, 2025
1326299
improve cosmx loading
laudmt Mar 17, 2025
a9fb827
Load labels using dask instead of numpy
marcovarrone Mar 17, 2025
0109cae
fix
laudmt Mar 17, 2025
8841a73
Remove print
marcovarrone Mar 18, 2025
12180a3
Write metadata
marcovarrone Mar 18, 2025
9873def
improve ram
laudmt Mar 18, 2025
5988743
Merge pull request #1 from marcovarrone/feat/cosmx_perf
marcovarrone Mar 18, 2025
b4ff59b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2025
20c3fb3
Remove unwanted README change
marcovarrone Mar 19, 2025
4af365f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
9d5feb1
Remove unwanted README change pt.2
marcovarrone Mar 19, 2025
f97f1b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
3 changes: 1 addition & 2 deletions README.md
@@ -111,8 +111,7 @@ Marconato, L., Palla, G., Yamauchi, K.A. et al. SpatialData: an open and univers
[link-docs]: https://spatialdata.scverse.org/projects/io/en/latest/
[link-api]: https://spatialdata.scverse.org/projects/io/en/latest/api.html
[link-cli]: https://spatialdata.scverse.org/projects/io/en/latest/cli.html

[//]: # (numfocus-fiscal-sponsor-attribution)
[//]: # "numfocus-fiscal-sponsor-attribution"

spatialdata-io is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
If you like scverse® and want to support our mission, please consider making a tax-deductible [donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs.
53 changes: 48 additions & 5 deletions src/spatialdata_io/readers/cosmx.py
@@ -16,7 +16,7 @@
from dask_image.imread import imread
from scipy.sparse import csr_matrix
from skimage.transform import estimate_transform
from spatialdata import SpatialData
from spatialdata import SpatialData, read_zarr
from spatialdata._logging import logger
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel
from spatialdata.transformations.transformations import Affine, Identity
@@ -34,6 +34,7 @@ def cosmx(
transcripts: bool = True,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
output_path: str | Path | None = None,
) -> SpatialData:
"""
Read *Cosmx Nanostring* data.
@@ -62,12 +63,20 @@
Keyword arguments passed to :func:`dask_image.imread.imread`.
image_models_kwargs
Keyword arguments passed to :class:`spatialdata.models.Image2DModel`.
output_path
Path where the output will be saved. If ``None``, the output will not be saved.

Returns
-------
:class:`spatialdata.SpatialData`
"""
path = Path(path)
output_path = Path(output_path) if output_path is not None else None
sdata = SpatialData()

# If output path is provided, save the empty SpatialData object to create directories and hierarchy
if output_path is not None:
sdata.write(output_path, overwrite=True)

# tries to infer dataset_id from the name of the counts file
if dataset_id is None:
@@ -151,6 +160,16 @@
inplace=True,
)

# Add table to SpatialData object, write it and delete temporary objects to save memory
sdata.tables["table"] = table
if output_path is not None:
sdata.write_element(element_name="table")
del adata
del table
del sdata.tables["table"]
del counts
del obs

# prepare to read images and labels
file_extensions = (".jpg", ".png", ".jpeg", ".tif", ".tiff")
pat = re.compile(r".*_F(\d+)")
@@ -195,7 +214,14 @@
rgb=None,
**image_models_kwargs,
)
images[f"{fov}_image"] = parsed_im
image_name = f"{fov}_image"
images[image_name] = parsed_im
if output_path is not None:
sdata.images[image_name] = parsed_im
sdata.write_element(element_name=image_name)
del parsed_im
del images[image_name]
del sdata.images[image_name]
else:
logger.warning(f"FOV {fov} not found in counts file. Skipping image {fname}.")

@@ -218,7 +244,14 @@
dims=("y", "x"),
**image_models_kwargs,
)
labels[f"{fov}_labels"] = parsed_la
label_name = f"{fov}_labels"
labels[label_name] = parsed_la
if output_path is not None:
sdata.labels[label_name] = parsed_la
sdata.write_element(element_name=label_name)
del parsed_la
del labels[label_name]
del sdata.labels[label_name]
else:
logger.warning(f"FOV {fov} not found in counts file. Skipping labels {fname}.")

@@ -256,6 +289,8 @@ def cosmx(
transcripts_data = pd.read_csv(path / transcripts_file, header=0)
transcripts_data.to_parquet(Path(tmpdir) / "transcripts.parquet")
print("done")
if output_path is not None:
del transcripts_data

ptable = pq.read_table(Path(tmpdir) / "transcripts.parquet")
for fov in fovs_counts:
@@ -265,7 +300,8 @@
# we rename z because we want to treat the data as 2d
sub_table.rename(columns={"z": "z_raw"}, inplace=True)
if len(sub_table) > 0:
points[f"{fov}_points"] = PointsModel.parse(
point_name = f"{fov}_points"
points[point_name] = PointsModel.parse(
sub_table,
coordinates={"x": CosmxKeys.X_LOCAL_TRANSCRIPT, "y": CosmxKeys.Y_LOCAL_TRANSCRIPT},
feature_key=CosmxKeys.TARGET_OF_TRANSCRIPT,
@@ -276,6 +312,11 @@
"global_only_labels": aff,
},
)
if output_path is not None:
sdata.points[point_name] = points[point_name]
sdata.write_element(element_name=point_name)
del points[point_name]
del sdata.points[point_name]

# TODO: what to do with fov file?
# if fov_file is not None:
@@ -286,5 +327,7 @@
# except KeyError:
# logg.warning(f"FOV `{str(fov)}` does not exist, skipping it.")
# continue

if output_path is not None:
sdata.write_consolidated_metadata()
return read_zarr(output_path)
return SpatialData(images=images, labels=labels, points=points, table=table)
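For reference, a hedged sketch of calling the reader with the new keyword once this branch is installed; the paths and dataset id below are hypothetical, and `xenium` (next file) gains the same `output_path` option:

```python
from spatialdata_io import cosmx

# With output_path set, each table/image/labels/points element is written to
# the Zarr store as soon as it is parsed and then released; the reader returns
# a SpatialData object read back lazily from disk.
sdata = cosmx(
    "/data/cosmx_run",               # hypothetical input directory
    dataset_id="my_dataset",         # hypothetical dataset id
    output_path="/data/cosmx.zarr",  # enables incremental, low-memory writing
)

# Omitting output_path keeps the previous all-in-memory behaviour.
sdata_in_memory = cosmx("/data/cosmx_run", dataset_id="my_dataset")
```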
80 changes: 58 additions & 22 deletions src/spatialdata_io/readers/xenium.py
@@ -26,7 +26,7 @@
from joblib import Parallel, delayed
from pyarrow import Table
from shapely import Polygon
from spatialdata import SpatialData
from spatialdata import SpatialData, read_zarr
from spatialdata._core.query.relational_query import get_element_instances
from spatialdata._types import ArrayLike
from spatialdata.models import (
@@ -67,6 +67,7 @@ def xenium(
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
labels_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
output_path: Path | None = None,
) -> SpatialData:
"""
Read a *10X Genomics Xenium* dataset into a SpatialData object.
@@ -123,6 +124,8 @@
Keyword arguments to pass to the image models.
labels_models_kwargs
Keyword arguments to pass to the labels models.
output_path
Path to a Zarr store to which every element is written as soon as it is read; this can decrease the memory requirement.

Returns
-------
@@ -159,6 +162,8 @@
image_models_kwargs, labels_models_kwargs
)
path = Path(path)
output_path = Path(output_path) if output_path is not None else None

with open(path / XeniumKeys.XENIUM_SPECS) as f:
specs = json.load(f)
# to trigger the warning if the version cannot be parsed
@@ -203,11 +208,10 @@
table.obs[XeniumKeys.Z_LEVEL] = cell_summary_table[XeniumKeys.Z_LEVEL]
table.obs[XeniumKeys.NUCLEUS_COUNT] = cell_summary_table[XeniumKeys.NUCLEUS_COUNT]

polygons = {}
labels = {}
tables = {}
points = {}
images = {}
sdata = SpatialData()

if output_path is not None:
sdata.write(output_path)

# From the public release notes here:
# https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/release-notes/release-notes-for-xoa
@@ -216,23 +220,29 @@
# nuclei to cells. Therefore for the moment we only link the table to the cell labels, and not to the nucleus
# labels.
if nucleus_labels:
labels["nucleus_labels"], _ = _get_labels_and_indices_mapping(
sdata.labels["nucleus_labels"], _ = _get_labels_and_indices_mapping(
path,
XeniumKeys.CELLS_ZARR,
specs,
mask_index=0,
labels_name="nucleus_labels",
labels_models_kwargs=labels_models_kwargs,
)
if output_path is not None:
sdata.write_element(element_name="nucleus_labels")
del sdata.labels["nucleus_labels"]
if cells_labels:
labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
sdata.labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
path,
XeniumKeys.CELLS_ZARR,
specs,
mask_index=1,
labels_name="cell_labels",
labels_models_kwargs=labels_models_kwargs,
)
if output_path is not None:
sdata.write_element(element_name="cell_labels")
del sdata.labels["cell_labels"]
if cell_labels_indices_mapping is not None and table is not None:
if not pd.DataFrame.equals(cell_labels_indices_mapping["cell_id"], table.obs[str(XeniumKeys.CELL_ID)]):
warnings.warn(
@@ -248,41 +258,53 @@
table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels"

if nucleus_boundaries:
polygons["nucleus_boundaries"] = _get_polygons(
sdata.shapes["nucleus_boundaries"] = _get_polygons(
path,
XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
specs,
n_jobs,
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
)

if output_path is not None:
sdata.write_element(element_name="nucleus_boundaries")
del sdata.shapes["nucleus_boundaries"]
if cells_boundaries:
polygons["cell_boundaries"] = _get_polygons(
sdata.shapes["cell_boundaries"] = _get_polygons(
path,
XeniumKeys.CELL_BOUNDARIES_FILE,
specs,
n_jobs,
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
)

if output_path is not None:
sdata.write_element(element_name="cell_boundaries")
del sdata.shapes["cell_boundaries"]
if transcripts:
points["transcripts"] = _get_points(path, specs)

sdata.points["transcripts"] = _get_points(path, specs)
if output_path is not None:
sdata.write_element(element_name="transcripts")
del sdata.points["transcripts"]
if version is None or version < packaging.version.parse("2.0.0"):
if morphology_mip:
images["morphology_mip"] = _get_images(
sdata.images["morphology_mip"] = _get_images(
path,
XeniumKeys.MORPHOLOGY_MIP_FILE,
imread_kwargs,
image_models_kwargs,
)
if output_path is not None:
sdata.write_element(element_name="morphology_mip")
del sdata.images["morphology_mip"]
if morphology_focus:
images["morphology_focus"] = _get_images(
sdata.images["morphology_focus"] = _get_images(
path,
XeniumKeys.MORPHOLOGY_FOCUS_FILE,
imread_kwargs,
image_models_kwargs,
)
if output_path is not None:
sdata.write_element(element_name="morphology_focus")
del sdata.images["morphology_focus"]
else:
if morphology_focus:
morphology_focus_dir = path / XeniumKeys.MORPHOLOGY_FOCUS_DIR
@@ -328,28 +350,42 @@ def filter(self, record: logging.LogRecord) -> bool:
"c_coords" not in image_models_kwargs
), "The channel names for the morphology focus images are handled internally"
image_models_kwargs["c_coords"] = list(channel_names.values())
images["morphology_focus"] = _get_images(
sdata.images["morphology_focus"] = _get_images(
morphology_focus_dir,
XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.format(0),
imread_kwargs,
image_models_kwargs,
)
del image_models_kwargs["c_coords"]
if output_path is not None:
sdata.write_element(element_name="morphology_focus")
del sdata.images["morphology_focus"]
logger.removeFilter(IgnoreSpecificMessage())

if table is not None:
tables["table"] = table
sdata.tables["table"] = table
if output_path is not None:
sdata.write_element(element_name="table")
del sdata.tables["table"]

elements_dict = {"images": images, "labels": labels, "points": points, "tables": tables, "shapes": polygons}
if cells_as_circles:
elements_dict["shapes"][specs["region"]] = circles
sdata = SpatialData(**elements_dict)
sdata.shapes[specs["region"]] = circles
if output_path is not None:
sdata.write_element(element_name=specs["region"])
del sdata.shapes[specs["region"]]

# find and add additional aligned images
if aligned_images:
extra_images = _add_aligned_images(path, imread_kwargs, image_models_kwargs)
for key, value in extra_images.items():
sdata.images[key] = value
if output_path is not None:
sdata.write_element(element_name=key)
del sdata.images[key]

if output_path is not None:
sdata.write_consolidated_metadata()
sdata = read_zarr(output_path)

return sdata

@@ -415,7 +451,7 @@ def _get_labels_and_indices_mapping(

with zarr.open(str(tmpdir), mode="r") as z:
# get the labels
masks = z["masks"][f"{mask_index}"][...]
masks = da.from_array(z["masks"][f"{mask_index}"])
labels = Labels2DModel.parse(
masks, dims=("y", "x"), transformations={"global": Identity()}, **labels_models_kwargs
)
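The final hunk swaps an eager NumPy materialisation (`z["masks"][...]`) for a lazy Dask wrapper, matching the "Load labels using dask instead of numpy" commit: label masks are then read chunk-by-chunk only when computed. A small self-contained illustration of the difference, assuming zarr-python 2.x and using a throwaway store in place of the Xenium cells file:

```python
import dask.array as da
import numpy as np
import zarr

# Throwaway store standing in for the masks group inside cells.zarr.
root = zarr.open("masks_demo.zarr", mode="w")
root.create_group("masks").create_dataset("0", data=np.zeros((4096, 4096), dtype=np.uint32))

z = zarr.open("masks_demo.zarr", mode="r")
eager = z["masks"]["0"][...]           # old: loads the whole array into RAM
lazy = da.from_array(z["masks"]["0"])  # new: chunked view, read on demand

print(type(eager).__name__, type(lazy).__name__)  # ndarray vs Array
print(int(lazy.sum().compute()))                  # chunks are read only here
```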