diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3ebbfe2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,54 @@ +ci: + autoupdate_schedule: quarterly +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-toml + - id: check-yaml + - id: detect-private-key + - id: end-of-file-fixer + - id: mixed-line-ending + - id: trailing-whitespace + - id: no-commit-to-branch + args: [--branch=main] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.9 + hooks: + - id: ruff + args: [--fix] + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.10.0 + hooks: + - id: mypy + additional_dependencies: + - "numpy" + - repo: https://github.com/doublify/pre-commit-rust + rev: v1.0 + hooks: + - id: fmt + - id: cargo-check + - repo: https://github.com/codespell-project/codespell + rev: v2.3.0 + hooks: + - id: codespell + additional_dependencies: + - tomli + - repo: https://github.com/google/yamlfmt + rev: v0.12.1 + hooks: + - id: yamlfmt + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.17 + hooks: + - id: mdformat diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..dd80752 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,13 @@ +version: 2 +build: + os: ubuntu-22.04 + tools: + python: "3.10" +sphinx: + configuration: docs/source/conf.py +python: + install: + - method: pip + path: . 
+ extra_requirements: + - docs diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..18865b6 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "_utils_rust" +version = "0.1.0" +edition = "2021" + +[lib] +name = "_utils_rust" +crate-type = ["cdylib"] + +[dependencies] +bincode = { version = "1.3" } +indexmap = { version = "2.1.0", features = ["rayon"] } +itertools = { version = "0.12.1" } +ndarray = { version = "0.15.6", features = ["rayon"] } +ndarray-stats = { version = "0.5.1" } +num = { version = "0.4.1" } +numpy = { version = "0.21.0" } +polars = { version = "0.39", features = ["partition_by", "dtype-categorical"] } +pyo3 = { version = "0.21.0", features = ["extension-module"] } +pyo3-polars = { version = "0.13.0" } +rayon = { version = "1.8.0" } +sprs = { version = "0.11.1", features = ["serde"] } diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8d75ed7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Niklas Müller-Bötticher, Naveed Ishaque, Roland Eils, Berlin Institute of Health @ Charité + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..a53347c --- /dev/null +++ b/README.md @@ -0,0 +1,42 @@ +# sainsc + +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Code style: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) +[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) +[![Checked with mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/) +[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit) + +/ˈsaiəns/ + +_"**S**egmentation-free **A**nalysis of **In S**itu **C**apture data"_ +or alternatively +"_**S**tupid **A**cronyms **In Sc**ience_" + +`sainsc` is a segmentation-free analysis tool for spatial transcriptomics from in situ +capture technologies (but also works for imaging-based technologies). It is easily +integratable with the [scverse](https://github.com/scverse) (i.e. `scanpy` and `squidpy`) +by exporting data in [`AnnData`](https://anndata.readthedocs.io/) or +[`SpatialData`](https://spatialdata.scverse.org/) format. + +## Installation + +`sainsc` will be made available on [PyPI](https://pypi.org/) and +[bioconda](https://bioconda.github.io/). For detailed installation instructions +please refer to the [documentation](https://sainsc.readthedocs.io/en/stable/installation.html). 
+ +## Documentation + +For an extensive documentation of the package please refer to the +[ReadTheDocs page](https://sainsc.readthedocs.io) + +## Versioning + +This project follows the [SemVer](https://semver.org/) guidelines for versioning. + +## Citations + +## License + +This project is licensed under the MIT License - for details please refer to the +[LICENSE](./LICENSE) file. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..ac97759 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/_templates/python/class.rst b/docs/source/_templates/python/class.rst new file mode 100644 index 0000000..f55ae76 --- /dev/null +++ b/docs/source/_templates/python/class.rst @@ -0,0 +1,117 @@ +{% if obj.display %} + {% if is_own_page %} +{{ obj.id }} +{{ "=" * obj.id | length }} + + {% endif %} + {% set visible_children = obj.children|selectattr("display")|list %} + {% set own_page_children = visible_children|selectattr("type", "in", own_page_types)|list %} + {% if is_own_page and own_page_children %} +.. toctree:: + :hidden: + + {% for child in own_page_children %} + {{ child.include_path }} + {% endfor %} + + {% endif %} +.. 
py:{{ obj.type }}:: {% if is_own_page %}{{ obj.id }}{% else %}{{ obj.short_name }}{% endif %}{% if obj.args %}({{ obj.args }}){% endif %} + + {% for (args, return_annotation) in obj.overloads %} + {{ " " * (obj.type | length) }} {{ obj.short_name }}{% if args %}({{ args }}){% endif %} + + {% endfor %} + {% if obj.bases %} + {% if "show-inheritance" in autoapi_options %} + + Bases: {% for base in obj.bases %}{{ base|link_objs }}{% if not loop.last %}, {% endif %}{% endfor %} + {% endif %} + + + {% if "show-inheritance-diagram" in autoapi_options and obj.bases != ["object"] %} + .. autoapi-inheritance-diagram:: {{ obj.obj["full_name"] }} + :parts: 1 + {% if "private-members" in autoapi_options %} + :private-bases: + {% endif %} + + {% endif %} + {% endif %} + {% if obj.docstring %} + + {{ obj.docstring|indent(3) }} + {% endif %} + {% for obj_item in visible_children %} + {% if obj_item.type not in own_page_types %} + + {{ obj_item.render()|indent(3) }} + {% endif %} + {% endfor %} + {% if is_own_page and own_page_children %} + {% set visible_attributes = own_page_children|selectattr("type", "equalto", "attribute")|list %} + {% if visible_attributes %} +Attributes +---------- + +.. autoapisummary:: + + {% for attribute in visible_attributes %} + {{ attribute.id }} + {% endfor %} + + + {% endif %} + {% set visible_properties = own_page_children|selectattr("type", "equalto", "property")|list %} + {% if visible_properties %} +Properties +---------- + +.. autoapisummary:: + + {% for property in visible_properties %} + {{ property.id }} + {% endfor %} + + + {% endif %} + {% set visible_exceptions = own_page_children|selectattr("type", "equalto", "exception")|list %} + {% if visible_exceptions %} +Exceptions +---------- + +.. 
autoapisummary:: + + {% for exception in visible_exceptions %} + {{ exception.id }} + {% endfor %} + + + {% endif %} + {% set visible_classes = own_page_children|selectattr("type", "equalto", "class")|list %} + {% if visible_classes %} +Classes +------- + +.. autoapisummary:: + + {% for klass in visible_classes %} + {{ klass.id }} + {% endfor %} + + + {% endif %} + {% set visible_methods = own_page_children|selectattr("type", "equalto", "method")|list %} + {% if visible_methods %} +Methods +------- + +.. autoapisummary:: + + {% for method in visible_methods %} + {{ method.id }} + {% endfor %} + + + {% endif %} + {% endif %} +{% endif %} diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..f317c6e --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,90 @@ +import importlib.metadata +import os +import sys +from datetime import datetime + +sys.path.insert(0, os.path.abspath("../..")) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "sainsc" +copyright = f""" +{datetime.now():%Y}, Niklas Müller-Bötticher, Naveed Ishaque, Roland Eils, +Berlin Institute of Health @ Charité""" +author = "Niklas Müller-Bötticher" +version = importlib.metadata.version("sainsc") +release = version + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + + +extensions = [ + "sphinx_copybutton", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", + "autoapi.extension", + "sphinx.ext.mathjax", +] + + +autodoc_typehints = "none" +autodoc_typehints_format = "short" + +autoapi_dirs = ["../../sainsc"] +autoapi_options = [ + "members", + "undoc-members", + "show-inheritance", + "show-module-summary", + "imported-members", +] +autoapi_python_class_content = "both" +autoapi_own_page_level = "attribute" +autoapi_template_dir 
= "_templates" +autoapi_member_order = "groupwise" + +python_use_unqualified_type_names = True # still experimental + +autosummary_generate = True +autosummary_imported_members = True + +nitpicky = True +nitpick_ignore = [ + ("py:class", "numpy.typing.DTypeLike"), + ("py:class", "numpy.typing.NDArray"), + ("py:mod", "polars"), + ("py:class", "polars.DataFrame"), + ("py:class", "optional"), +] + +exclude_patterns: list[str] = ["_templates"] + +intersphinx_mapping = dict( + anndata=("https://anndata.readthedocs.io/en/stable/", None), + matplotlib=("https://matplotlib.org/stable/", None), + numpy=("https://numpy.org/doc/stable/", None), + pandas=("https://pandas.pydata.org/pandas-docs/stable/", None), + polars=("https://docs.pola.rs/py-polars/html/", None), + python=("https://docs.python.org/3", None), + scipy=("https://docs.scipy.org/doc/scipy/", None), + seaborn=("https://seaborn.pydata.org/", None), + spatialdata=("https://spatialdata.scverse.org/en/stable/", None), +) + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "sphinx_rtd_theme" +html_static_path: list[str] = [] + + +def skip_submodules(app, what, name, obj, skip, options): + if what == "module": + skip = True + return skip + + +def setup(sphinx): + sphinx.connect("autoapi-skip-member", skip_submodules) diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..ee2de68 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,25 @@ +What is sainsc? +=================== + +sainsc (pronounced /ˈsaiəns/) is a segmentation-free analysis tool for spatial +transcriptomics from in situ capture technologies (but also works for +imaging-based technologies). + +It is easily integratable with the `scverse `_ +(i.e. `scanpy` and `squidpy`) by exporting data in +`AnnData `_ or +`SpatialData `_ format. + +.. Citations +.. --------- + +.. tbd + +.. 
toctree:: + :maxdepth: 1 + :caption: Contents: + + self + quickstart + installation + usage diff --git a/docs/source/installation.rst b/docs/source/installation.rst new file mode 100644 index 0000000..5843594 --- /dev/null +++ b/docs/source/installation.rst @@ -0,0 +1,61 @@ +Installation +============ + + +PyPI and ``pip`` +---------------- + +``sainsc`` will soon be available to install from `PyPI `_. + +.. To install ``sainsc`` from `PyPI `_ using ``pip`` just run + +.. .. code-block:: bash + +.. pip install sainsc + +.. If you want to have support for :py:mod:`spatialdata` use + +.. .. code-block:: bash + +.. pip install 'sainsc[spatialdata]' + + +Bioconda and ``conda`` +---------------------- + +``sainsc`` is not yet available for +`Miniconda `_ installations. But we are +planning to add it to the `bioconda `_ channel soon. + + +.. Alternatively, if you prefer the installation using +.. `Miniconda `_ you can do that from the +.. `bioconda `_ channel. + +.. .. code-block:: bash + +.. conda install -c bioconda sainsc + +.. .. note:: + +.. Of course, it is also possible to use ``mamba`` instead of ``conda`` +.. to speed up the installation. + + +From GitHub +----------- + +You can install the latest versions directly from +`GitHub <https://github.com/HiDiHlabs/sainsc>`_. To do so clone the repository using the +``git clone`` command. Navigate into the downloaded directory and install using + +.. code-block:: bash + + pip install . + +If you want to install the development version you can install the additional optional +dependencies with + +.. code-block:: bash + + pip install -e '.[dev]' diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst new file mode 100644 index 0000000..33e746c --- /dev/null +++ b/docs/source/quickstart.rst @@ -0,0 +1,23 @@ +Quickstart / tl;dr +================== + +The general workflow will look like this: + +You first read your data into an :py:class:`sainsc.LazyKDE` object. 
+The data can be filtered, subset, and cropped to adjust the desired field of view and +genes. + +In the next step the kernel for kernel density estimation (KDE) is defined and +cell types are assigned to each pixel using cell type gene +expression signatures from e.g. single-cell RNAseq. + +Otherwise you can find the local maxima of the KDE and treat these as proxies for cells. +From that point on you can proceed using standard single-cell RNAseq analysis and +spatial methods (e.g. using `scanpy `_ and +`squidpy `_). + +Along the way you will want to (and should) generate a lot of plots to check your +results. + +For a more concrete example of what a workflow can look like we will provide example +notebooks soon. diff --git a/docs/source/usage.rst b/docs/source/usage.rst new file mode 100644 index 0000000..d4b9a15 --- /dev/null +++ b/docs/source/usage.rst @@ -0,0 +1,2 @@ +Usage +===== diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..89d262a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,84 @@ +[build-system] +requires = ["setuptools>=61.0.0", "setuptools_scm[toml]>=6.2", "setuptools-rust>=1.7"] +build-backend = "setuptools.build_meta" + + +[project] +name = "sainsc" +description = "Segmentation-free Analysis of In Situ Capture data" +readme = { file = "README.md", content-type = "text/markdown" } +license = { file = "LICENSE" } +requires-python = ">=3.10" +dynamic = ["version"] + +authors = [ + { name = "Niklas Müller-Bötticher", email = "niklas.mueller-boetticher@charite.de" }, +] +dependencies = [ + "anndata>=0.9", + "matplotlib", + "matplotlib-scalebar", + "numba>=0.44", + "numpy>=1.21", + "pandas", + "polars[pandas]", + "scikit-image>=0.18", + "scipy>=1.9", + "seaborn>=0.11", +] +classifiers = [ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python 
:: 3 :: Only", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Rust", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "Typing :: Typed", +] + +[project.optional-dependencies] +spatialdata = ["spatialdata>=0.1"] +docs = ["sphinx", "sphinx-autoapi>=3.1", "sphinx-copybutton", "sphinx-rtd-theme"] +dev = ["sainsc[docs]", "pre-commit"] + +[project.urls] +homepage = "https://github.com/HiDiHlabs/sainsc" +documentation = "https://sainsc.readthedocs.io" +repository = "https://github.com/HiDiHlabs/sainsc" + + +[tool] + +[tool.setuptools] +zip-safe = false + +[tool.setuptools.packages.find] +# TODO improve +include = ["sainsc", "sainsc.io", "sainsc.lazykde"] + +[tool.setuptools_scm] + +[[tool.setuptools-rust.ext-modules]] +target = "sainsc._utils_rust" + +[tool.isort] +profile = "black" + +[tool.black] +target-version = ["py310", "py311", "py312"] + +[tool.ruff] +target-version = "py310" + +[tool.mypy] +python_version = "3.10" +ignore_missing_imports = true +warn_no_return = false +packages = "sainsc" +plugins = "numpy.typing.mypy_plugin" + +[tool.codespell] +ignore-words-list = "coo,crate" diff --git a/sainsc/__init__.py b/sainsc/__init__.py new file mode 100644 index 0000000..a152f6d --- /dev/null +++ b/sainsc/__init__.py @@ -0,0 +1,19 @@ +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("sainsc") +except PackageNotFoundError: + __version__ = "unknown version" + +from sainsc._utils_rust import GridCounts + +from .io import read_StereoSeq, read_StereoSeq_bins +from .lazykde import LazyKDE, gaussian_kernel + +__all__ = [ + "GridCounts", + "LazyKDE", + "gaussian_kernel", + "read_StereoSeq", + "read_StereoSeq_bins", +] diff --git a/sainsc/_typealias.py b/sainsc/_typealias.py new file mode 100644 index 0000000..94a2df1 --- /dev/null +++ b/sainsc/_typealias.py @@ -0,0 +1,21 @@ +import os +from typing import TypeAlias + +import numpy as np +from numpy.typing import NDArray +from 
scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix + +_PathLike: TypeAlias = os.PathLike[str] | str + +_Csr: TypeAlias = csr_array | csr_matrix +_Csc: TypeAlias = csc_array | csc_matrix +_Csx: TypeAlias = _Csr | _Csc +_CsxArray: TypeAlias = csc_array | csr_array + +_RangeTuple: TypeAlias = tuple[int, int] +_RangeTuple2D: TypeAlias = tuple[_RangeTuple, _RangeTuple] + +_Local_Max: TypeAlias = tuple[NDArray[np.int_], NDArray[np.int_]] + +_Color: TypeAlias = str | tuple[float, float, float] +_Cmap: TypeAlias = str | list[_Color] | dict[str, _Color] diff --git a/sainsc/_utils.py b/sainsc/_utils.py new file mode 100644 index 0000000..55f0d1e --- /dev/null +++ b/sainsc/_utils.py @@ -0,0 +1,45 @@ +import os +from typing import NoReturn + +import numpy as np +import pandas as pd +from numpy.typing import NDArray + +from sainsc._utils_rust import coordinate_as_string + + +def _get_n_cpus() -> int: + return len(os.sched_getaffinity(0)) + + +def _get_coordinate_index( + x: NDArray[np.integer], + y: NDArray[np.integer], + *, + name: str | None = None, + n_threads: int = 1, +) -> pd.Index: + x_i32: NDArray[np.int32] = x.astype(np.int32, copy=False) + y_i32: NDArray[np.int32] = y.astype(np.int32, copy=False) + + return pd.Index( + coordinate_as_string(x_i32, y_i32, n_threads=n_threads), dtype=str, name=name + ) + + +def _bin_coordinates(df: pd.DataFrame, bin_size: float) -> pd.DataFrame: + df = df.assign( + x=lambda df: _get_bin_coordinate(df["x"].to_numpy(), bin_size), + y=lambda df: _get_bin_coordinate(df["y"].to_numpy(), bin_size), + ) + return df + + +def _get_bin_coordinate(coor: NDArray[np.number], bin_size: float) -> NDArray[np.int32]: + return np.floor(coor / bin_size).astype(np.int32, copy=False) + + +def _raise_module_load_error(e: Exception, fn: str, pkg: str, extra: str) -> NoReturn: + raise ModuleNotFoundError( + f"`{fn}` requires '{pkg}' to be installed, e.g. via the '{extra}' extra." 
+ ) from e diff --git a/sainsc/_utils_rust.pyi b/sainsc/_utils_rust.pyi new file mode 100644 index 0000000..d6fcde0 --- /dev/null +++ b/sainsc/_utils_rust.pyi @@ -0,0 +1,258 @@ +import numpy as np +from numpy.typing import NDArray +from polars import DataFrame + +from ._typealias import _Csx, _CsxArray + +def sparse_kde_csx_py( + counts: _Csx, kernel: NDArray[np.float32], *, threshold: float = 0 +) -> _CsxArray: + """ + Calculate the KDE for each spot with counts as uint16. + """ + +def kde_at_coord( + counts: GridCounts, + genes: list[str], + kernel: NDArray[np.float32], + coordinates: tuple[NDArray[np.int_], NDArray[np.int_]], + *, + n_threads: int | None = None, +) -> _CsxArray: + """ + Calculate KDE at the given coordinates. + """ + +def categorical_coordinate( + x: NDArray[np.int32], y: NDArray[np.int32], *, n_threads: int | None = None +) -> tuple[NDArray[np.int32], NDArray[np.int32]]: + """ + Get the codes and the coordinates (comparable to a pandas.Categorical) + """ + +def coordinate_as_string( + x: NDArray[np.int32], y: NDArray[np.int32], *, n_threads: int | None = None +) -> NDArray[np.str_]: + """ + Concatenate two int arrays elementwise into a string representation (i.e. 'x_y'). + """ + +def cosinef32_and_celltypei8( + counts: GridCounts, + genes: list[str], + signatures: NDArray[np.float32], + kernel: NDArray[np.float32], + *, + log: bool = False, + chunk_size: tuple[int, int] = (500, 500), + n_threads: int | None = None, +) -> tuple[NDArray[np.float32], NDArray[np.int8]]: + """ + Calculate the cosine similarity given counts and signatures and assign the most + similar celltype. 
+ """ + +def cosinef32_and_celltypei16( + counts: GridCounts, + genes: list[str], + signatures: NDArray[np.float32], + kernel: NDArray[np.float32], + *, + log: bool = False, + chunk_size: tuple[int, int] = (500, 500), + n_threads: int | None = None, +) -> tuple[NDArray[np.float32], NDArray[np.int16]]: + """ + Calculate the cosine similarity given counts and signatures and assign the most + similar celltype. + """ + +class GridCounts: + """ + Object holding each gene as count data in a sparse 2D-grid. + """ + + shape: tuple[int, int] + """ + tuple[int, int]: Shape of the count arrays. + """ + + def __init__( + self, + counts: dict[str, _Csx], + *, + resolution: float | None = None, + n_threads: int | None = None, + ): + """ + Parameters + ---------- + counts : dict[str, scipy.sparse.csr_array | scipy.sparse.csr_matrix| scipy.sparse.csc_array| scipy.sparse.csc_matrix] + Gene counts. + resolution : float, optional + Resolution as nm / pixel. + n_threads : int, optional + Number of threads used for reading and processing file. If `None` this will + default to the number of logical CPUs. + + Raises + ------ + ValueError + If genes in `counts` do not all have the same shape. + """ + + @classmethod + def from_dataframe( + cls, + df: DataFrame, + *, + resolution: float | None = None, + binsize: float | None = None, + n_threads: int | None = None, + ): # -> Self + """ + Initialize from dataframe. + + Transform a :py:class:`polars.DataFrame` that provides a 'gene', 'x', and 'y' + column into :py:class:`sainsc.GridCounts`. If a 'count' column exists it will + be used as counts else a count of 1 (single molecule) per row will be assumed. + + Parameters + ---------- + df : polars.DataFrame + The data to be transformed. + binsize : float or None, optional + The size to bin the coordinates by. If None coordinates must be integers. + resolution : float, optional + Resolution of each coordinate unit in nm. The default is 1,000 i.e. measurements + are in um. 
+ n_threads : int, optional + Number of threads used for initializing :py:class:`sainsc.LazyKDE`. + If `None` this will default to the number of logical CPUs. + + Returns + ------- + sainsc.GridCounts + """ + + def __getitem__(self, key: str) -> _CsxArray: ... + def __setitem__(self, key: str, value: _Csx): ... + def __delitem__(self, key: str): ... + def __len__(self) -> int: ... + def __contains__(self, item: str) -> bool: ... + def __eq__(self, other) -> bool: ... + def __ne__(self, other) -> bool: ... + def get(self, key: str, default: _CsxArray | None = None) -> _CsxArray | None: + """ + Get the counts for a gene. + + Parameters + ---------- + key : str + Name of the gene to retrieve. + + Returns + ------- + scipy.sparse.csr_array | scipy.sparse.csc_array | None + """ + + def genes(self) -> list[str]: + """ + Get all available genes. + + Returns + ------- + list[str] + """ + + def gene_counts(self) -> dict[str, int]: + """ + Number of counts per gene. + + Returns + ------- + dict[str, int] + Mapping from gene to number of counts. + """ + + def grid_counts(self) -> NDArray[np.uintc]: + """ + Counts per pixel. + + Aggregates counts across all genes. + + Returns + ------- + numpy.ndarray[numpy.uintc] + """ + + def select_genes(self, genes: set[str]): + """ + Keep selected genes. + + Parameters + ---------- + genes : set[str] + List of gene names to keep. + """ + + def filter_genes_by_count(self, min: int = 1, max: int = 4_294_967_295): + """ + Filter genes by minimum and maximum count thresholds. + + Parameters + ---------- + min : int, optional + Minimum count threshold. + max : int, optional + Maximum count threshold. + """ + + def crop(self, x: tuple[int | None, int | None], y: tuple[int | None, int | None]): + """ + Crop the field of view for all genes. 
+ + Parameters + ---------- + x : tuple[int | None, int | None] + Range to crop as `(xmin, xmax)` + y : tuple[int | None, int | None] + Range to crop as `(ymin, ymax)` + """ + + def filter_mask(self, mask: NDArray[np.bool_]): + """ + Filter all genes with a binary mask. + + Parameters + ---------- + mask : numpy.ndarray[numpy.bool] + All counts where `mask` is `False` will be set to 0. + """ + + @property + def resolution(self) -> float | None: + """ + float | None: Resolution in nm / pixel. + + Raises + ------ + TypeError + If setting with a type other than `float` or `int`. + """ + + @resolution.setter + def resolution(self, resolution: float): ... + @property + def n_threads(self) -> int: + """ + int: Number of threads used for processing. + + Raises + ------ + TypeError + If setting with a type other than `int` or less than 0. + """ + + @n_threads.setter + def n_threads(self, n_threads: int): ... diff --git a/sainsc/io/__init__.py b/sainsc/io/__init__.py new file mode 100644 index 0000000..373bd74 --- /dev/null +++ b/sainsc/io/__init__.py @@ -0,0 +1,3 @@ +from ._io import read_gem_file, read_StereoSeq, read_StereoSeq_bins + +__all__ = ["read_gem_file", "read_StereoSeq", "read_StereoSeq_bins"] diff --git a/sainsc/io/_io.py b/sainsc/io/_io.py new file mode 100644 index 0000000..8d80432 --- /dev/null +++ b/sainsc/io/_io.py @@ -0,0 +1,304 @@ +from pathlib import Path +from typing import TYPE_CHECKING, Literal, get_args + +import numpy as np +import pandas as pd +import polars as pl +from anndata import AnnData +from scipy.sparse import csr_matrix + +from .._typealias import _PathLike +from .._utils import ( + _bin_coordinates, + _get_coordinate_index, + _get_n_cpus, + _raise_module_load_error, +) +from .._utils_rust import GridCounts +from ..lazykde import LazyKDE +from ._io_utils import _categorical_coordinate, _open_file +from ._stereoseq_chips import CHIP_RESOLUTION + +if TYPE_CHECKING: + from spatialdata import SpatialData + + +# Stereo-seq +def 
_get_stereo_header(filepath: _PathLike) -> dict[str, str]: + header = dict() + with _open_file(filepath, "rb") as f: + for line in f: + assert isinstance(line, bytes) + if not line.startswith(b"#"): + break + key, value = line.decode().strip("#\n").split("=", 1) + header[key] = value + + return header + + +def _get_stereo_resolution(name: str) -> int | None: + for chip_name in CHIP_RESOLUTION.keys(): + if name.startswith(chip_name): + return CHIP_RESOLUTION[chip_name] + + +def read_gem_file( + filepath: _PathLike, *, sep: str = "\t", n_threads: int | None = None, **kwargs +) -> pl.DataFrame: + """ + Read a GEM file into a DataFrame. + + GEM files are used by e.g. Stereo-Seq and Nova-ST. + + The gene-ID and count column will be renamed to `gene` and `count`, respectively. + + The name of the count column must be one of `MIDCounts`, `MIDCount`, or `UMICount`. + + Parameters + ---------- + filepath : os.PathLike or str + Path to the GEM file. + sep : str, optional + Separator used in :py:func:`polars.read_csv`. + n_threads : int, optional + Number of threads used for reading and processing file. If `None` this will + default to the number of available CPUs. + kwargs + Other keyword arguments will be passed to :py:func:`polars.read_csv`. + + Returns + ------- + polars.DataFrame + + Raises + ------ + ValueError + If count column has an unknown name. 
+ """ + _Count_ColName = Literal["MIDCounts", "MIDCount", "UMICount"] + + if n_threads is None: + n_threads = _get_n_cpus() + + path = Path(filepath) + + columns = pl.read_csv(path, separator=sep, comment_char="#", n_rows=0).columns + count_col = None + for name in get_args(_Count_ColName): + if name in columns: + count_col = name + break + + if count_col is None: + options = get_args(_Count_ColName) + raise ValueError( + f"Unknown count column, the name of the column must be one of {options}" + ) + df = pl.read_csv( + path, + separator=sep, + comment_char="#", + dtypes={ + "geneID": pl.Categorical, + "x": pl.Int32, + "y": pl.Int32, + count_col: pl.UInt32, + }, + n_threads=n_threads, + **kwargs, + ) + df = df.rename({count_col: "count", "geneID": "gene"}) + + return df + + +def read_StereoSeq( + filepath: _PathLike, + *, + resolution: float | None = None, + sep: str = "\t", + n_threads: int | None = None, + **kwargs, +) -> LazyKDE: + """ + Read a Stereo-seq GEM file. + + Parameters + ---------- + filepath : os.PathLike or str + Path to the Stereo-seq file. + resolution : float, optional + Center-to-center distance of Stere-seq beads in nm, if None + it will try to detect it from the chip definition in the file header + if one exists. + sep : str, optional + Separator used in :py:func:`polars.read_csv`. + n_threads : int, optional + Number of threads used for reading and processing file. If `None` this will + default to the number of available CPUs. + kwargs + Other keyword arguments will be passed to :py:func:`polars.read_csv`. 
+ + Returns + ------- + sainsc.LazyKDE + """ + if n_threads is None: + n_threads = _get_n_cpus() + + if resolution is None: + chip = _get_stereo_header(filepath).get("StereoChip") + if chip is not None: + resolution = _get_stereo_resolution(chip) + + df = read_gem_file(filepath, sep=sep, n_threads=n_threads, **kwargs) + + counts = GridCounts.from_dataframe( + df, binsize=None, resolution=resolution, n_threads=n_threads + ) + + return LazyKDE(counts, n_threads=n_threads) + + +def read_StereoSeq_bins( + filepath: _PathLike, + bin_size: int = 50, + *, + spatialdata: bool = False, + resolution: float | None = None, + sep: str = "\t", + n_threads: int | None = None, + **kwargs, +) -> "AnnData | SpatialData": + """ + Read a Stereo-seq GEM file into bins. + + Parameters + ---------- + filepath : os.PathLike or str + Path to the Stereo-seq file. + bin_size : int, optional + Defines the size of bins along both dimensions + e.g. 50 will result in bins of size 50x50. + spatialdata : bool, optional + If True will load the data as a SpatialData object else as an AnnData object. + resolution : float, optional + Center-to-center distance of Stereo-seq beads in nm, if None + it will try to detect it from the chip definition in the file header + if one exists. + sep : str, optional + Separator used in :py:func:`polars.read_csv`. + n_threads : int, optional + Number of threads used for reading and processing file. If `None` this will + default to the number of available CPUs. + kwargs + Other keyword arguments will be passed to :py:func:`polars.read_csv`. + + Returns + ------- + anndata.AnnData | spatialdata.SpatialData + AnnData or SpatialData object of the bins with coordinates stored in + :py:attr:`anndata.AnnData.obsm` with the key `'spatial'`. + + Raises + ------ + ModuleNotFoundError + If `spatialdata` is set to `True` but the package is not installed. 
+ """ + if n_threads is None: + n_threads = _get_n_cpus() + + header = _get_stereo_header(filepath) + df = read_gem_file(filepath, sep=sep, n_threads=n_threads, **kwargs) + df = df.with_columns(pl.col(i) - pl.col(i).min() for i in ["x", "y"]) + df = _bin_coordinates(df.to_pandas(), bin_size) + + coord_codes, coordinates = _categorical_coordinate( + df.pop("x").to_numpy(), df.pop("y").to_numpy(), n_threads=n_threads + ) + + # Duplicate entries in csr_matrix are summed which automatically gives bin merging + counts = csr_matrix( + (df.pop("count"), (coord_codes, df["gene"].cat.codes)), + shape=(coordinates.shape[0], df["gene"].cat.categories.size), + dtype=np.int32, + ) + + del coord_codes + + obs = pd.DataFrame( + index=_get_coordinate_index( + coordinates[:, 0], coordinates[:, 1], name="bin", n_threads=n_threads + ) + ) + genes = pd.DataFrame(index=pd.Index(df["gene"].cat.categories, name="gene")) + + if resolution is None: + chip = header.get("StereoChip") + if chip is not None: + resolution = _get_stereo_resolution(chip) + + adata = AnnData( + X=counts, + obs=obs, + var=genes, + obsm={"spatial": coordinates}, + uns={"file_header": header, "resolution": resolution, "bin_size": bin_size}, + ) + + if spatialdata: + try: + from geopandas import GeoDataFrame + from shapely import Polygon + from spatialdata import SpatialData + from spatialdata.models import ShapesModel, TableModel + + bin_name = f"bins{bin_size}" + + x, y = adata.obsm["spatial"].T + del adata.obsm["spatial"] + + df = pd.DataFrame({"x": x, "y": y}, index=adata.obs_names).assign( + x1=lambda df: df["x"] * bin_size, + x2=lambda df: df["x1"] + bin_size, + y1=lambda df: df["y"] * bin_size, + y2=lambda df: df["y1"] + bin_size, + ) + + shapes = ShapesModel.parse( + GeoDataFrame( + { + "geometry": df.apply( + lambda r: Polygon( + [ + (r["x1"], r["y1"]), + (r["x1"], r["y2"]), + (r["x2"], r["y2"]), + (r["x2"], r["y1"]), + ] + ), + axis=1, + ) + } + ) + ) + + adata.obs["region"] = bin_name + adata.obs["region"] 
= adata.obs["region"].astype("category") + adata.obs["instance_key"] = adata.obs_names + table = TableModel.parse( + adata, region=bin_name, region_key="region", instance_key="instance_key" + ) + + return SpatialData( + shapes={bin_name: shapes}, tables={f"{bin_name}_annotation": table} + ) + + except ModuleNotFoundError as e: + _raise_module_load_error( + e, "read_StereoSeq_bins", pkg="spatialdata", extra="spatialdata" + ) + + else: + return adata diff --git a/sainsc/io/_io_utils.py b/sainsc/io/_io_utils.py new file mode 100644 index 0000000..05f47bc --- /dev/null +++ b/sainsc/io/_io_utils.py @@ -0,0 +1,30 @@ +import gzip +from pathlib import Path +from typing import Literal + +import numpy as np +from numpy.typing import NDArray + +from sainsc._utils_rust import categorical_coordinate + +from .._typealias import _PathLike + + +def _categorical_coordinate( + x: NDArray[np.int32], y: NDArray[np.int32], *, n_threads: int = 1 +) -> tuple[NDArray[np.int32], NDArray[np.int32]]: + assert len(x) == len(y) + + return categorical_coordinate(x, y, n_threads=n_threads) + + +# currently no need to support all file modes +_File_Mode = Literal["r", "w", "rb", "wb"] + + +def _open_file(file: _PathLike, mode: _File_Mode = "r"): + file = Path(file) + if file.suffix == ".gz": + return gzip.open(file, mode) + else: + return open(file, mode) diff --git a/sainsc/io/_stereoseq_chips.py b/sainsc/io/_stereoseq_chips.py new file mode 100644 index 0000000..add4383 --- /dev/null +++ b/sainsc/io/_stereoseq_chips.py @@ -0,0 +1,28 @@ +CHIP_RESOLUTION = { + "A": 500, + "B": 500, + "C": 500, + "CL1": 900, + "D": 500, + "DP40": 700, + "DP8": 850, + "DP84": 715, + "E1": 700, + "F1": 800, + "F3": 715, + "FP1": 600, + "FP2": 500, + "G1": 700, + "K2": 715, + "N1": 900, + "S1": 900, + "S2": 715, + "SS2": 500, + "U": 715, + "V": 715, + "V1": 800, + "V3": 715, + "W": 715, + "X": 715, + "Y": 500, +} diff --git a/sainsc/lazykde/_LazyKDE.py b/sainsc/lazykde/_LazyKDE.py new file mode 100644 index 
class LazyKDE:
    """
    Class to analyze kernel density estimates (KDE) for large number of genes.

    The KDE of the genes will be calculated when needed to avoid storing large volumes
    of data in memory.
    """

    def __init__(
        self,
        counts: GridCounts,
        *,
        n_threads: int | None = None,
    ):
        """
        Parameters
        ----------
        counts : sainsc.GridCounts
            Gene counts.
        n_threads : int, optional
            Number of threads used for reading and processing file. If `None` this will
            default to the number of available CPUs.
        """
        if n_threads is None:
            n_threads = _get_n_cpus()

        self.counts: GridCounts = counts
        """
        sainsc.GridCounts : Spatial gene counts.
        """

        self.counts.n_threads = n_threads

        self._threads = n_threads

        # lazily computed caches; populated by the calculate_*/find_*/assign_* methods
        self._total_mRNA: NDArray[np.unsignedinteger] | None = None
        self._total_mRNA_KDE: NDArray[np.float32] | None = None
        self._background: NDArray[np.bool_] | None = None
        self._local_maxima: _Local_Max | None = None
        self._celltype_map: NDArray[np.signedinteger] | None = None
        self._cosine_similarity: NDArray[np.float32] | None = None
        self._celltypes: list[str] | None = None

    ## Kernel
    def gaussian_kernel(
        self, sigma: float, *, truncate: float = 2, circular: bool = False
    ):
        """
        Set the kernel used for kernel density estimation (KDE) to gaussian.

        Parameters
        ----------
        sigma : float
            Bandwidth of the kernel.
        truncate : float, optional
            The radius for calculating the KDE is calculated as `sigma` * `truncate`.
            Refer to :py:func:`scipy.ndimage.gaussian_filter`.
        circular : bool, optional
            If `True` calculate the KDE using a circular kernel instead of square by
            setting all values outside the radius `sigma` * `truncate` to 0.

        See Also
        --------
        :py:meth:`sainsc.LazyKDE.kde`
        """
        dtype = np.float32
        radius = round(truncate * sigma)
        self.kernel = gaussian_kernel(sigma, radius, dtype=dtype, circular=circular)

    ## KDE
    def kde(self, gene: str, *, threshold: float | None = None) -> _CsxArray:
        """
        Calculate kernel density estimate (KDE).

        The kernel will be used from :py:attr:`sainsc.LazyKDE.kernel`.

        Parameters
        ----------
        gene : str
            Gene for which to calculate the KDE.
        threshold : float, optional
            Relative threshold of maximum of the kernel that is used to filter beads.
            All values below :math:`threshold * max(kernel)` are set to 0. Filtering is done
            after calculating the KDE which sets it apart from reducing `truncate`.

        Returns
        -------
        scipy.sparse.csc_array | scipy.sparse.csr_array
            Sparse array of the KDE for `gene`.

        See Also
        --------
        :py:meth:`sainsc.LazyKDE.gaussian_kernel`
        """
        return self._kde(self.counts[gene], threshold)

    def _kde(self, arr: NDArray | _Csx, threshold: float | None = None) -> _CsxArray:
        # TODO use ndimage.gaussian_filter for "dense" arrays?
        # ndimage.gaussian_filter(arr, sigma, mode="constant", cval=0, truncate=2)
        if threshold is None:
            threshold = 0

        # the Rust sparse-KDE backend only accepts CSR/CSC input
        if isinstance(arr, np.ndarray):
            arr = csr_array(arr)

        if arr.dtype == np.uint32:
            return sparse_kde_csx_py(arr, self.kernel, threshold=threshold)
        else:
            raise TypeError("Sparse KDE currently only supports 'numpy.uint32'")

    def calculate_total_mRNA(self):
        """
        Calculate the total mRNA map by aggregating the counts across all genes.

        The result is stored in :py:attr:`sainsc.LazyKDE.total_mRNA`; no KDE
        smoothing is applied here.

        See Also
        --------
        :py:meth:`sainsc.LazyKDE.calculate_total_mRNA_KDE`
        """

        self._total_mRNA = self.counts.grid_counts()

    def calculate_total_mRNA_KDE(self):
        """
        Calculate kernel density estimate (KDE) for the total mRNA.

        If :py:attr:`sainsc.LazyKDE.total_mRNA` has not been calculated
        :py:meth:`sainsc.LazyKDE.calculate_total_mRNA` is run first.

        See Also
        --------
        :py:meth:`sainsc.LazyKDE.gaussian_kernel`
        :py:meth:`sainsc.LazyKDE.kde`
        """
        # recompute the raw total-mRNA map if missing or stale (shape mismatch)
        if self.total_mRNA is None or self.total_mRNA.shape != self.shape:
            self.calculate_total_mRNA()
        total_mRNA_counts = self.total_mRNA
        assert total_mRNA_counts is not None
        self._total_mRNA_KDE = self._kde(total_mRNA_counts).toarray()
    def find_local_maxima(self, min_dist: int, min_area: int = 0):
        """
        Find the local maxima of the kernel density estimates.

        The local maxima are detected from the KDE of the total mRNA stored in
        :py:attr:`sainsc.LazyKDE.total_mRNA_KDE`. Background as defined in
        :py:attr:`sainsc.LazyKDE.background` will be removed before identifying
        local maxima.

        Parameters
        ----------
        min_dist : int
            Minimum distance between two maxima in pixels.
        min_area : int, optional
            Minimum area of connected pixels that are not background to be
            considered for maxima detection. Allows ignoring maxima in noisy spots.
        """
        if self.total_mRNA_KDE is None:
            raise ValueError(
                "`total_mRNA_KDE` must be calculated before finding local maxima"
            )

        # Restrict peak detection to foreground if a background mask exists.
        if self.background is not None:
            foreground = ~self.background
            if min_area > 0:
                # drop small connected foreground patches (noise)
                foreground = _filter_blobs(foreground, min_area)
        else:
            foreground = None

        local_max = peak_local_max(
            self.total_mRNA_KDE,
            min_distance=min_dist,
            exclude_border=False,
            labels=foreground,
        )

        # stored as a (x-coordinates, y-coordinates) tuple
        self._local_maxima = (local_max[:, 0], local_max[:, 1])

    def load_local_maxima(
        self, genes: Iterable[str] | None = None, *, spatialdata: bool = False
    ) -> "AnnData | SpatialData":
        """
        Load the gene expression (KDE) of the local maxima.

        The local maxima (:py:attr:`sainsc.LazyKDE.local_maxima`) are calculated and
        returned as :py:class:`anndata.AnnData` object.

        Parameters
        ----------
        genes : collections.abc.Iterable[str], optional
            List of genes for which the KDE will be calculated.
        spatialdata : bool, optional
            If True will load the data as a SpatialData object including the totalRNA
            projection and cell-type map if available. If False an AnnData object is
            returned.

        Returns
        -------
        anndata.AnnData | spatialdata.SpatialData

        Raises
        ------
        ModuleNotFoundError
            If `spatialdata` is set to `True` but the package is not installed.
        """
        if self.local_maxima is None:
            raise ValueError("`local_maxima` have to be identified before loading")

        if genes is None:
            genes = self.genes
        else:
            genes = list(genes)

        # KDE is only evaluated at the maxima coordinates (cheaper than the full grid)
        kde = self._load_KDE_maxima(genes)
        adata = _localmax_anndata(
            kde, genes, self.local_maxima, name="local_maxima", n_threads=self.n_threads
        )

        if spatialdata:
            try:
                from spatialdata import SpatialData
                from spatialdata.models import (
                    Image2DModel,
                    Labels2DModel,
                    PointsModel,
                    TableModel,
                )

                x, y = adata.obsm["spatial"].T
                del adata.obsm["spatial"]

                localmax_name = "local_maxima"

                local_max = PointsModel.parse(
                    pd.DataFrame({"x": x, "y": y}, index=adata.obs_names)
                )

                # link the annotation table to the points element
                adata.obs["region"] = localmax_name
                adata.obs["region"] = adata.obs["region"].astype("category")
                adata.obs["instance_key"] = adata.obs_names

                local_max_anno = TableModel.parse(
                    adata,
                    region=localmax_name,
                    region_key="region",
                    instance_key="instance_key",
                )

                sdata_dict: dict[str, Any] = {
                    localmax_name: local_max,
                    f"{localmax_name}_annotation": local_max_anno,
                }

                if self.total_mRNA_KDE is not None:

                    sdata_dict["total_mRNA"] = Image2DModel.parse(
                        np.atleast_3d(self.total_mRNA_KDE).T, dims=("c", "y", "x")
                    )

                if self.celltype_map is not None:
                    label_name = "celltype_map"

                    # shift indices by +1 so 0 can encode background in the labels image
                    labels = self.celltype_map + 1
                    if self.background is not None:
                        labels[self.background] = 0

                    sdata_dict[label_name] = Labels2DModel.parse(
                        labels.T, dims=("y", "x")
                    )

                    obs = pd.DataFrame(
                        {"region": label_name, "instance_key": self.celltypes},
                        index=self.celltypes,
                    ).astype({"region": "category"})

                    sdata_dict[f"{label_name}_annotation"] = TableModel.parse(
                        AnnData(obs=obs),
                        region=label_name,
                        region_key="region",
                        instance_key="instance_key",
                    )

                return SpatialData.from_elements_dict(sdata_dict)

            except ModuleNotFoundError as e:
                _raise_module_load_error(
                    e, "load_local_maxima", pkg="spatialdata", extra="spatialdata"
                )

        else:
            return adata
return adata + + def _load_KDE_maxima(self, genes: list[str]) -> csc_array | csr_array: + + assert self.local_maxima is not None + return kde_at_coord( + self.counts, genes, self.kernel, self.local_maxima, n_threads=self.n_threads + ) + + ## Celltyping + def filter_background( + self, + min_norm: float | dict[str, float], + min_cosine: float | dict[str, float] | None = None, + ): + """ + Assign beads as background. + + Parameters + ---------- + min_norm : float or dict[str, float] + The threshold for defining background based on + :py:attr:`sainsc.LazyKDE.total_mRNA_KDE`. + Either a float which is used as global threshold or a mapping from cell-types + to thresholds. Cell-type assignment is needed for cell type-specific thresholds. + min_cosine : float or dict[str, float], optional + The threshold for defining background based on the minimum cosine + similarity. Cell type-specific thresholds can be defined as for `min_norm`. + + Raises + ------ + ValueError + If cell type-specific thresholds do not include all cell-types. 
+ """ + + @njit + def _map_celltype_to_value( + ct_map: NDArray[np.integer], dict: dict[int, float] + ) -> NDArray[np.floating]: + values = np.zeros(shape=ct_map.shape, dtype=float) + for i in range(ct_map.shape[0]): + for j in range(ct_map.shape[1]): + if ct_map[i, j] >= 0: + values[i, j] = dict[ct_map[i, j]] + return values + + if self.total_mRNA_KDE is None: + raise ValueError( + "`total_mRNA_KDE` needs to be calculated before filtering background" + ) + + if isinstance(min_norm, dict): + if self.celltypes is None or self.celltype_map is None: + raise ValueError( + "Cell type-specific threshold can only be used after cell-type assignment" + ) + elif not all([ct in min_norm.keys() for ct in self.celltypes]): + raise ValueError("'min_norm' does not contain all celltypes.") + idx2threshold = {idx: min_norm[ct] for idx, ct in enumerate(self.celltypes)} + threshold = _map_celltype_to_value(self.celltype_map, idx2threshold) + background = self.total_mRNA_KDE < threshold + else: + background = self.total_mRNA_KDE < min_norm + + if min_cosine is not None: + if self.cosine_similarity is None: + raise ValueError( + "Cosine similarity threshold can only be used after cell-type assignment" + ) + if isinstance(min_cosine, dict): + if self.celltypes is None or self.celltype_map is None: + raise ValueError( + "Cell type-specific threshold can only be used after cell-type assignment" + ) + elif not all([ct in min_cosine.keys() for ct in self.celltypes]): + raise ValueError("'min_cosine' does not contain all celltypes.") + idx2threshold = { + idx: min_cosine[ct] for idx, ct in enumerate(self.celltypes) + } + threshold = _map_celltype_to_value(self.celltype_map, idx2threshold) + background &= self.cosine_similarity >= threshold + else: + background &= self.cosine_similarity >= min_cosine + + self._background = background + + @staticmethod + def _calculate_cosine_celltype_fn(dtype) -> CosineCelltypeCallable: + if dtype == np.int8: + return cosinef32_and_celltypei8 + elif dtype == 
np.int16: + return cosinef32_and_celltypei16 + else: + raise NotImplementedError + + def assign_celltype( + self, + signatures: pd.DataFrame, + *, + log: bool = False, + chunk: tuple[int, int] = (500, 500), + ): + """ + Calculate the cosine similarity with known cell-type signatures. + + For each bead calculate the cosine similarity with a set of cell-type signatures. + The cell-type with highest score will be assigned to the corresponding bead. + + Parameters + ---------- + signatures : pandas.DataFrame + DataFrame of cell-type signatures. Columns are cell-types and index are genes. + log : bool + Whether to log transform the KDE when calculating the cosine similarity. + This is useful if the gene signatures are derived from log-transformed data. + chunk : tuple[int, int] + Size of the chunks for processing. Larger chunks require more memory but + have less duplicated computation. + """ + + if not all(signatures.index.isin(self.genes)): + raise ValueError( + "Not all genes in the gene signature are part of this KDE." 
+ ) + + if not all(s < c for s, c in zip(self.kernel.shape, chunk)): + raise ValueError("`chunk` must be larger than shape of kernel.") + + dtype = np.float32 + + celltypes = signatures.columns.tolist() + ct_dtype = _get_cell_dtype(len(celltypes)) + + # scale signatures to unit norm + signatures_mat = signatures.to_numpy() + signatures_mat = ( + signatures_mat / np.linalg.norm(signatures_mat, axis=0) + ).astype(dtype, copy=False) + + genes = signatures.index.to_list() + + fn = self._calculate_cosine_celltype_fn(ct_dtype) + + self._cosine_similarity, self._celltype_map = fn( + self.counts, + genes, + signatures_mat, + self.kernel, + log=log, + chunk_size=chunk, + n_threads=self.n_threads, + ) + self._celltypes = celltypes + + ## Plotting + def _plot_2d( + self, + img: NDArray, + title: str, + *, + remove_background: bool = False, + crop: _RangeTuple2D | None = None, + scalebar: bool = True, + im_kwargs: dict = dict(), + scalebar_kwargs: dict = _SCALEBAR, + ) -> Figure: + if remove_background: + if self.background is not None: + img[self.background] = 0 + else: + raise ValueError("`background` is undefined") + + if crop is not None: + img = img[tuple(slice(*c) for c in crop)] + fig, ax = plt.subplots(1, 1) + assert isinstance(ax, Axes) + im = ax.imshow(img.T, origin="lower", **im_kwargs) + ax.set_title(title) + + divider = axes_grid1.make_axes_locatable(ax) + cax = divider.append_axes("right", size="3%", pad=0.1) + fig.colorbar(im, cax=cax) + + if scalebar: + self._add_scalebar(ax, **scalebar_kwargs) + return fig + + def _add_scalebar(self, ax: Axes, **kwargs): + if self.counts.resolution is None: + raise ValueError("'resolution' must be set when using scalebar") + ax.add_artist(ScaleBar(self.counts.resolution, **kwargs)) + + def plot_genecount_histogram(self, **kwargs) -> Figure: + """ + Plot a histogram of the counts per gene. 
+ + Parameters + ---------- + kwargs + Other keyword arguments are passed to :py:func:`seaborn.histplot` + + Returns + ------- + matplotlib.figure.Figure + """ + fig, ax = plt.subplots(1, 1) + assert isinstance(ax, Axes) + sns.histplot( + np.fromiter(self.counts.gene_counts().values(), dtype=int), + log_scale=True, + ax=ax, + **kwargs, + ) + ax.set(xlabel="Counts per gene", ylabel="# genes") + return fig + + def plot_KDE_histogram( + self, + *, + gene: str | None = None, + remove_background: bool = False, + crop: _RangeTuple2D | None = None, + **kwargs, + ) -> Figure: + """ + Plot a histogram of the kernel density estimates. + + Plots either the kernel density estimate (KDE) of the total mRNA + (:py:attr:`sainsc.LazyKDE.total_mRNA_KDE`) or of a single gene if `gene` is + provided. + + Parameters + ---------- + gene : str, optional + Gene for which the KDE histogram is plotted. + remove_background : bool, optional + If `True`, all pixels for which :py:attr:`sainsc.LazyKDE.background` is + `False` are set to 0. + crop : tuple[tuple[int, int], tuple[int, int]], optional + Coordinates to crop the data defined as `((xmin, xmax), (ymin, ymax))`. + kwargs + Other keyword arguments are passed to :py:func:`matplotlib.pyplot.hist`. 
+ + Returns + ------- + matplotlib.figure.Figure + + See Also + -------- + :py:meth:`sainsc.LazyKDE.kde` + """ + name = "total mRNA" if gene is None else gene + + if gene is not None: + kde = self.kde(gene) + else: + if self.total_mRNA_KDE is not None: + kde = self.total_mRNA_KDE + else: + raise ValueError("`total_mRNA_KDE` has not been calculated") + + if remove_background: + if self.background is not None: + kde[self.background] = 0 + else: + raise ValueError("`background` is undefined") + + if crop is not None: + kde = kde[tuple(slice(*c) for c in crop)] + + fig, ax = plt.subplots(1, 1) + assert isinstance(ax, Axes) + ax.hist(coo_array(kde).data, **kwargs) + ax.set(xlabel=f"KDE of {name}", ylabel="# pixels") + return fig + + def plot_genecount( + self, + *, + gene: str | None = None, + crop: _RangeTuple2D | None = None, + scalebar: bool = True, + im_kwargs: dict = dict(), + scalebar_kwargs: dict = _SCALEBAR, + ) -> Figure: + """ + Plot the gene expression counts. + + By default this will plot the :py:attr:`sainsc.LazyKDE.total_mRNA`. If + `gene` is specified the respective gene will be plotted. + + Parameters + ---------- + gene : str, optional + Gene in :py:attr:`sainsc.LazyKDE.genes` to use for plotting. + crop : tuple[tuple[int, int], tuple[int, int]], optional + Coordinates to crop the data defined as `((xmin, xmax), (ymin, ymax))`. + scalebar : bool, optional + If `True`, add a ``matplotlib_scalebar.scalebar.ScaleBar`` to the plot. + im_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to :py:func:`matplotlib.pyplot.imshow`. + scalebar_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to ``matplotlib_scalebar.scalebar.ScaleBar``. + + Returns + ------- + matplotlib.figure.Figure + + Raises + ------ + ValueError + If :py:attr:`sainsc.LazyKDE.total_mRNA` has not been calculated. 
+ """ + if gene is not None: + img = self.counts[gene].toarray() + else: + if self.total_mRNA is not None: + img = self.total_mRNA + else: + raise ValueError( + "`total_mRNA` has not been calculated." + "Run `calculate_total_mRNA` first." + ) + title = "total mRNA" if gene is None else gene + + return self._plot_2d( + img, + title, + crop=crop, + scalebar=scalebar, + im_kwargs=im_kwargs, + scalebar_kwargs=scalebar_kwargs, + ) + + def plot_KDE( + self, + *, + gene: str | None = None, + remove_background: bool = False, + crop: _RangeTuple2D | None = None, + scalebar: bool = True, + im_kwargs: dict = dict(), + scalebar_kwargs: dict = _SCALEBAR, + ) -> Figure: + """ + Plot the kernel density estimate (KDE). + + By default this will plot the KDE of the total mRNA + (:py:attr:`sainsc.LazyKDE.total_mRNA_KDE`). If `gene` is specified the + respective KDE will be computed and plotted. + + Parameters + ---------- + gene : str, optional + Gene in :py:attr:`sainsc.LazyKDE.genes` to use for plotting. + remove_background : bool, optional + If `True`, all pixels for which :py:attr:`sainsc.LazyKDE.background` is + `False` are set to 0. + crop : tuple[tuple[int, int], tuple[int, int]], optional + Coordinates to crop the data defined as `((xmin, xmax), (ymin, ymax))`. + scalebar : bool, optional + If `True`, add a ``matplotlib_scalebar.scalebar.ScaleBar`` to the plot. + im_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to :py:func:`matplotlib.pyplot.imshow`. + scalebar_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to ``matplotlib_scalebar.scalebar.ScaleBar``. + + Returns + ------- + matplotlib.figure.Figure + + Raises + ------ + ValueError + If :py:attr:`sainsc.LazyKDE.total_mRNA_KDE` has not been calculated. 
+ + See Also + -------- + :py:meth:`sainsc.LazyKDE.kde` + """ + if gene is not None: + img = self.kde(gene).toarray() + else: + if self.total_mRNA_KDE is not None: + img = self.total_mRNA_KDE + else: + raise ValueError( + "`total_mRNA_KDE` has not been calculated." + "Run `calculate_total_mRNA_KDE` first." + ) + + title = "KDE of " + ("total mRNA" if gene is None else gene) + + return self._plot_2d( + img, + title, + remove_background=remove_background, + crop=crop, + scalebar=scalebar, + im_kwargs=im_kwargs, + scalebar_kwargs=scalebar_kwargs, + ) + + def plot_local_maxima( + self, + *, + crop: _RangeTuple2D | None = None, + background_kwargs: dict = dict(), + scatter_kwargs: dict = dict(), + ) -> Figure: + """ + Plot the local kernel density estimate maxima. + + Parameters + ---------- + crop : tuple[tuple[int, int], tuple[int, int]], optional + Coordinates to crop the data defined as `((xmin, xmax), (ymin, ymax))`. + background_kwargs : dict, optional + Keyword arguments that are passed to :py:meth:`sainsc.LazyKDE.plot_KDE`. + scatter_kwargs : dict, optional + Keyword arguments that are passed to :py:func:`matplotlib.pyplot.scatter`. + + Returns + ------- + matplotlib.figure.Figure + + See Also + -------- + :py:meth:`sainsc.LazyKDE.find_local_maxima` + """ + if self.local_maxima is None: + raise ValueError + + x, y = self.local_maxima + + if crop is not None: + x_min, x_max = crop[0] + y_min, y_max = crop[1] + keep = (x >= x_min) & (y >= y_min) & (x < x_max) & (y < y_max) + x = x[keep] - x_min + y = y[keep] - y_min + + fig = self.plot_KDE(crop=crop, **background_kwargs) + fig.axes[0].scatter(x, y, **scatter_kwargs) + return fig + + def plot_celltypemap( + self, + *, + remove_background: bool = True, + crop: _RangeTuple2D | None = None, + scalebar: bool = True, + cmap: _Cmap = "hls", + scalebar_kwargs: dict = _SCALEBAR, + ) -> Figure: + """ + Plot the cell-type annotation. 
+ + Parameters + ---------- + remove_background : bool, optional + If `True`, all pixels for which :py:attr:`sainsc.LazyKDE.background` is + `False` are set to 0. + crop : tuple[tuple[int, int], tuple[int, int]], optional + Coordinates to crop the data defined as `((xmin, xmax), (ymin, ymax))`. + scalebar : bool, optional + If `True`, add a ``matplotlib_scalebar.scalebar.ScaleBar`` to the plot. + cmap : str or list or dict, optional + If it is a string it must be the name of a `cmap` that can be used in + :py:func:`seaborn.color_palette`. + If it is a list of colors it must have the same length as the number of + celltypes. + If it is a dictionary it must be a mapping from celltpye to color. Undefined + celltypes are plotted as `'grey'`. + Colors can either be provided as string that can be converted via + :py:func:`matplotlib.colors.to_rgb` or as ``(r, g, b)``-tuple between 0-1. + scalebar_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to ``matplotlib_scalebar.scalebar.ScaleBar``. 
+ + Returns + ------- + matplotlib.figure.Figure + + + See Also + -------- + :py:meth:`sainsc.LazyKDE.assign_celltype` + """ + if self.celltypes is None or self.celltype_map is None: + raise ValueError("celltype assignment missing") + + n_celltypes = len(self.celltypes) + + celltype_map = self.celltype_map + if remove_background: + if self.background is None: + raise ValueError("Background has not been filtered.") + else: + celltype_map[self.background] = -1 + + if crop is not None: + celltype_map = celltype_map[tuple(slice(*c) for c in crop)] + + if isinstance(cmap, str): + color_map = sns.color_palette(cmap, n_colors=n_celltypes) + assert isinstance(color_map, Iterable) + else: + if isinstance(cmap, list): + if len(cmap) != n_celltypes: + raise ValueError("You need to provide 1 color per celltype") + + elif isinstance(cmap, dict): + cmap = [cmap.get(cell, "grey") for cell in self.celltypes] + + color_map = [to_rgb(c) if isinstance(c, str) else c for c in cmap] + + # convert to uint8 to reduce memory of final image + color_map_int = tuple( + (np.array(c) * 255).round().astype(np.uint8) for c in color_map + ) + img = _apply_color(celltype_map.T, color_map_int) + + legend_elements = [ + Line2D( + [0], + [0], + marker="o", + color="w", + label=lbl, + markerfacecolor=c, + markersize=10, + ) + for c, lbl in zip(color_map, self.celltypes) + ] + + fig, ax = plt.subplots() + assert isinstance(ax, Axes) + ax.imshow(img, origin="lower") + ax.legend( + title="Cell type", + handles=legend_elements, + ncols=-(n_celltypes // -20), # ceildiv + loc="center left", + bbox_to_anchor=(1, 0.5), + ) + + if scalebar: + self._add_scalebar(ax, **scalebar_kwargs) + + return fig + + def plot_cosine_similarity( + self, + *, + remove_background: bool = False, + crop: _RangeTuple2D | None = None, + scalebar: bool = True, + im_kwargs: dict = dict(), + scalebar_kwargs: dict = _SCALEBAR, + ) -> Figure: + """ + Plot the cosine similarity from cell-type assignment. 
+ + Parameters + ---------- + remove_background : bool, optional + If `True`, all pixels for which :py:attr:`sainsc.LazyKDE.background` is + `False` are set to 0. + crop : tuple[tuple[int, int], tuple[int, int]], optional + Coordinates to crop the data defined as `((xmin, xmax), (ymin, ymax))`. + scalebar : bool, optional + If `True`, add a ``matplotlib_scalebar.scalebar.ScaleBar`` to the plot. + im_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to :py:func:`matplotlib.pyplot.imshow`. + scalebar_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to ``matplotlib_scalebar.scalebar.ScaleBar``. + + Returns + ------- + matplotlib.figure.Figure + + See Also + -------- + :py:meth:`sainsc.LazyKDE.assign_celltype` + """ + if self.cosine_similarity is not None: + img = self.cosine_similarity + else: + raise ValueError("Cell types have not been assigned") + + return self._plot_2d( + img, + "Cosine similarity", + remove_background=remove_background, + crop=crop, + scalebar=scalebar, + im_kwargs=im_kwargs, + scalebar_kwargs=scalebar_kwargs, + ) + + ## Attributes + @property + def n_threads(self) -> int: + """ + int: Number of threads that will be used for computations. + + Raises + ------ + TypeError + If setting with a type other than `int` or less than 0. + """ + return self._threads + + @n_threads.setter + def n_threads(self, n_threads: int): + if isinstance(n_threads, int) and n_threads >= 0: + if n_threads == 0: + n_threads = _get_n_cpus() + self._threads = n_threads + self.counts.n_threads = n_threads + else: + raise TypeError("`n_threads` must be an `int` >= 0.") + + @property + def shape(self) -> tuple[int, int]: + """ + tuple[int, int]: Shape of the sample. + """ + return self.counts.shape + + @property + def genes(self) -> list[str]: + """ + list[str]: List of genes. + """ + return self.counts.genes() + + @property + def resolution(self) -> float | None: + """ + float: Resolution in nm / pixel. 
+ + Raises + ------ + TypeError + If setting with a type other than `float` or `int`. + """ + return self.counts.resolution + + @resolution.setter + def resolution(self, resolution: float): + self.counts.resolution = resolution + + @property + def kernel(self) -> np.ndarray: + """ + numpy.ndarray: Map of the KDE of total mRNA. + + Raises + ------ + ValueError + If kernel is not a square, 2D :py:class:`numpy.ndarray` of uneven length. + """ + return self._kernel.copy() + + @kernel.setter + def kernel(self, kernel: np.ndarray): + if ( + len(kernel.shape) != 2 + or kernel.shape[0] != kernel.shape[1] + or any(i % 2 == 0 for i in kernel.shape) + ): + raise ValueError( + "`kernel` currently only supports 2D squared arrays of uneven length." + ) + else: + self._kernel = kernel.astype(np.float32) + + @property + def local_maxima(self) -> _Local_Max | None: + """ + tuple[numpy.ndarray[numpy.signedinteger], ...]: Coordinates of local maxima. + """ + return self._local_maxima + + @property + def total_mRNA(self) -> NDArray[np.unsignedinteger] | None: + """ + numpy.ndarray[numpy.unsignedinteger]: Map of the total mRNA. + """ + return self._total_mRNA + + @property + def total_mRNA_KDE(self) -> NDArray[np.single] | None: + """ + numpy.ndarray[numpy.single]: Map of the KDE of total mRNA. + """ + return self._total_mRNA_KDE + + @property + def background(self) -> NDArray[np.bool] | None: + """ + numpy.ndarray[numpy.bool]: Map of pixels that are assigned as background. + + Raises + ------ + TypeError + If setting with array that is not of type `numpy.bool`. + ValueError + If setting with array that has different shape than `self`. + """ + return self._background + + @background.setter + def background(self, background: NDArray[np.bool]): + if background.shape != self.shape: + raise ValueError("`background` must have same shape as `self`") + else: + self._background = background + + @property + def celltypes(self) -> list[str] | None: + """ + list[str]: List of assigned celltypes. 
+ """ + return self._celltypes + + @property + def cosine_similarity(self) -> NDArray[np.single] | None: + """ + numpy.ndarray[numpy.single]: Cosine similarity for each pixel. + """ + return self._cosine_similarity + + @property + def celltype_map(self) -> NDArray[np.signedinteger] | None: + """ + numpy.ndarray[numpy.signedinteger]: Cell-type map of cell-type indices. + + Each number corresponds to the index in :py:attr:`sainsc.LazyKDE.celltypes`, + and -1 to unassigned (background). + """ + return self._celltype_map + + def __str__(self) -> str: + repr = [ + f"LazyKDE ({self.n_threads} threads)", + f"genes: {len(self.genes)}", + f"shape: {self.shape}", + ] + if self.resolution is not None: + repr.append(f"resolution: {self.resolution} nm / px") + if self.background is not None: + repr.append("background: set") + if self.local_maxima is not None: + repr.append(f"local maxima: {len(self.local_maxima[0])}") + if self.celltypes is not None: + repr.append(f"celltypes: {len(self.celltypes)}") + + spacing = " " + + return f"\n{spacing}".join(repr) diff --git a/sainsc/lazykde/__init__.py b/sainsc/lazykde/__init__.py new file mode 100644 index 0000000..acad7e7 --- /dev/null +++ b/sainsc/lazykde/__init__.py @@ -0,0 +1,4 @@ +from ._kernel import epanechnikov_kernel, gaussian_kernel +from ._LazyKDE import LazyKDE + +__all__ = ["LazyKDE", "epanechnikov_kernel", "gaussian_kernel"] diff --git a/sainsc/lazykde/_kernel.py b/sainsc/lazykde/_kernel.py new file mode 100644 index 0000000..c3686a8 --- /dev/null +++ b/sainsc/lazykde/_kernel.py @@ -0,0 +1,103 @@ +import math +from typing import TypeVar + +import numpy as np +from numpy.typing import DTypeLike, NDArray +from scipy import ndimage, signal + +T = TypeVar("T", bound=np.number) + + +def _make_circular_mask(radius: int) -> NDArray[np.bool_]: + diameter = radius * 2 + 1 + x, y = np.ogrid[:diameter, :diameter] + dist_from_center = np.sqrt((x - radius) ** 2 + (y - radius) ** 2) + return dist_from_center <= radius + + +def 
_make_circular_kernel(kernel: NDArray[T], radius: int) -> NDArray[T]: + kernel[~_make_circular_mask(radius)] = 0 + return kernel + + +def gaussian_kernel( + sigma: float, + radius: int, + *, + dtype: DTypeLike = np.float32, + circular: bool = False, + **kwargs, +) -> NDArray: + """ + Generate a 2D Gaussian kernel array. + + Parameters + ---------- + sigma : float + Bandwidth of the Gaussian. + radius : int + Radius of the kernel. Output size will be :math:`2*radius+1`. + dtype : numpy.typing.DTypeLike + Datatype of the kernel. + circular : bool, optional + Whether to make kernel circular. Values outside `radius` will be set to 0. + kwargs : + Other keyword arguments will be passed to + :py:func:`scipy.ndimage.gaussian_filter`. + + Returns + ------- + numpy.ndarray + """ + mask_size = 2 * radius + 1 + + dirac = signal.unit_impulse((mask_size, mask_size), idx="mid") + + gaussian_kernel = ndimage.gaussian_filter( + dirac, sigma, output=np.float64, **kwargs + ).astype(dtype) + + if circular: + gaussian_kernel = _make_circular_kernel(gaussian_kernel, radius) + + return gaussian_kernel + + +def epanechnikov_kernel(sigma: float, *, dtype: DTypeLike = np.float32) -> np.ndarray: + """ + Generate a 2D Epanechnikov kernel array. + + :math:`K(x) = 1/2 * c_d^{-1}*(d+2)(1-||x||^2)` if :math:`||x|| < 1` else 0, + where :math:`d` is the number of dimensions + and :math:`c_d` the volume of the unit `d`-dimensional sphere. + + Parameters + ---------- + sigma : float + Bandwidth of the kernel. + dtype : numpy.typing.DTypeLike, optional + Datatype of the kernel. 
+ + Returns + ------- + numpy.ndarray + """ + # https://doi.org/10.1109/CVPR.2000.854761 + # c_d = pi for d=2 + + r = math.ceil(sigma) + dia = 2 * r - 1 # values at r are zero anyways so the kernel matrix can be smaller + + # 1/2 * pi^-1 * (d+2) + scale = 2 / math.pi + + kernel = np.zeros((dia, dia), dtype=dtype) + for i in range(dia): + for j in range(dia): + x = i - r + 1 + y = j - r + 1 + norm = (x / sigma) ** 2 + (y / sigma) ** 2 + if norm < 1: + kernel[i, j] = scale * (1 - norm) + + return kernel diff --git a/sainsc/lazykde/_utils.py b/sainsc/lazykde/_utils.py new file mode 100644 index 0000000..3a8723f --- /dev/null +++ b/sainsc/lazykde/_utils.py @@ -0,0 +1,81 @@ +from collections.abc import Iterable +from typing import Protocol, TypeVar + +import numpy as np +import pandas as pd +from anndata import AnnData +from numba import njit +from numpy.typing import NDArray +from scipy.sparse import csr_matrix, sparray, spmatrix +from skimage.measure import label, regionprops + +from sainsc._utils_rust import GridCounts + +from .._utils import _get_coordinate_index + +T = TypeVar("T", bound=np.number) +U = TypeVar("U", bound=np.bool_ | np.integer) + +_SCALEBAR = {"units": "nm", "box_alpha": 0, "color": "w"} + + +@njit +def _apply_color( + img_in: NDArray[np.integer], cmap: tuple[NDArray[T], ...] +) -> NDArray[T]: + img = np.zeros(shape=(*img_in.shape, 3), dtype=cmap[0].dtype) + for i in range(img_in.shape[0]): + for j in range(img_in.shape[1]): + if img_in[i, j] >= 0: + img[i, j, :] = cmap[img_in[i, j]] + return img + + +def _get_cell_dtype(n: int) -> np.dtype: + return np.result_type("int8", n) + + +def _filter_blobs(labeled_map: NDArray[U], min_blob_area: int) -> NDArray[U]: + # remove small blops (i.e. 
"cells") + if min_blob_area <= 0: + raise ValueError("Area must be bigger than 0") + blob_labels = label(labeled_map, background=0) + for blop in regionprops(blob_labels): + if blop.area_filled < min_blob_area: + min_x, min_y, max_x, max_y = blop.bbox + labeled_map[min_x:max_x, min_y:max_y][blop.image] = 0 + return labeled_map + + +def _localmax_anndata( + kde: spmatrix | sparray | NDArray, + genelist: Iterable[str], + coord: tuple[NDArray[np.integer], ...], + *, + name: str | None = None, + n_threads: int = 1, +) -> AnnData: + obs = pd.DataFrame( + index=_get_coordinate_index(*coord, name=name, n_threads=n_threads) + ) + + return AnnData( + X=csr_matrix(kde), + obs=obs, + var=pd.DataFrame(index=pd.Index(genelist, name="gene")), + obsm={"spatial": np.column_stack(coord)}, + ) + + +class CosineCelltypeCallable(Protocol): + def __call__( + self, + counts: GridCounts, + genes: list[str], + signatures: NDArray[np.float32], + kernel: NDArray[np.float32], + *, + log: bool = ..., + chunk_size: tuple[int, int] = ..., + n_threads: int | None = ..., + ) -> tuple[NDArray[np.float32], NDArray[np.signedinteger]]: ... 
diff --git a/sainsc/py.typed b/sainsc/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/coordinates.rs b/src/coordinates.rs new file mode 100644 index 0000000..cfbc3bd --- /dev/null +++ b/src/coordinates.rs @@ -0,0 +1,165 @@ +use crate::utils::create_pool; +use indexmap::IndexMap; +use ndarray::{Array1, Array2, ArrayView1, Zip}; +use num::{one, zero, PrimInt, Zero}; +use numpy::{IntoPyArray, PyArray1, PyArray2, PyFixedString, PyReadonlyArray1}; +use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use rayon::{prelude::ParallelIterator, ThreadPoolBuildError}; +use std::{fmt::Display, hash::Hash, ops::AddAssign}; + +type CoordInt = i32; +type CodeInt = i32; + +#[pyfunction] +#[pyo3(signature = (x, y, *, n_threads=None))] +/// Concatenate two int arrays into a string separated by underscore +pub fn coordinate_as_string<'py>( + py: Python<'py>, + x: PyReadonlyArray1<'py, CoordInt>, + y: PyReadonlyArray1<'py, CoordInt>, + n_threads: Option, +) -> PyResult>>> { + match string_coordinate_index_(x.as_array(), y.as_array(), n_threads.unwrap_or(0)) { + Ok(string_coordinates) => Ok(string_coordinates + .map(|s| (*s).into()) + .into_pyarray_bound(py)), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + } +} + +#[pyfunction] +#[pyo3(signature = (x, y, *, n_threads=None))] +/// From a list of coordinates extract a categorical representation +pub fn categorical_coordinate<'py>( + py: Python<'py>, + x: PyReadonlyArray1<'py, CoordInt>, + y: PyReadonlyArray1<'py, CoordInt>, + n_threads: Option, +) -> PyResult<( + Bound<'py, PyArray1>, + Bound<'py, PyArray2>, +)> { + match categorical_coordinate_(x.as_array(), y.as_array(), n_threads.unwrap_or(0)) { + Ok((codes, coordinates)) => Ok(( + codes.into_pyarray_bound(py), + coordinates.into_pyarray_bound(py), + )), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + } +} + +//// pure Rust part + +/// Concatenate two int arrays into a 'string' array +fn string_coordinate_index_<'a, X, const N: usize>( + x: 
ArrayView1<'a, X>, + y: ArrayView1<'a, X>, + n_threads: usize, +) -> Result, ThreadPoolBuildError> +where + X: Display, + &'a X: Send, + ArrayView1<'a, X>: Sync + Send, +{ + let thread_pool = create_pool(n_threads)?; + + let mut x_y = Array1::from_elem(x.len(), [0u8; N]); + thread_pool.install(|| { + Zip::from(&mut x_y).and(x).and(y).par_for_each(|xy, x, y| { + let mut xy_string = String::with_capacity(N); + xy_string.push_str(&x.to_string()); + xy_string.push('_'); + xy_string.push_str(&y.to_string()); + + let xy_bytes = xy_string.as_bytes(); + (*xy)[..xy_bytes.len()].copy_from_slice(xy_bytes); + }); + }); + + Ok(x_y) +} + +fn categorical_coordinate_<'a, C, X>( + x: ArrayView1<'a, X>, + y: ArrayView1<'a, X>, + n_threads: usize, +) -> Result<(Array1, Array2), ThreadPoolBuildError> +where + C: PrimInt + Sync + Send + AddAssign, + X: Copy + Zero + Sync + Send, + (X, X): Eq + PartialEq + Hash, + &'a X: Send, + ArrayView1<'a, X>: Sync + Send, +{ + let thread_pool = create_pool(n_threads)?; + + let n = x.len(); + let n_coord_estimate = n / 5; // rough guess of size to reduce allocations + + let mut coord2idx: IndexMap<(X, X), C> = IndexMap::with_capacity(n_coord_estimate); + + { + let mut cnt = zero::(); + Zip::from(x).and(y).for_each(|x, y| { + coord2idx.entry((*x, *y)).or_insert_with(|| { + let curr = cnt; + cnt += one(); + curr + }); + }); + } + + let codes = thread_pool.install(|| { + Zip::from(x).and(y).par_map_collect(|x, y| { + *coord2idx + .get(&(*x, *y)) + .expect("All coordinates are initialized in HashMap") + }) + }); + + let coordinates = thread_pool.install(|| { + coord2idx + .par_keys() + .map(|row| [row.0, row.1]) + .collect::>() + .into() + }); + + Ok((codes, coordinates)) +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::array; + + #[test] + fn test_string_coordinate() { + let a = array![0, 1, 99_999]; + let b = array![0, 20, 99_999]; + + let a_b: Array1<[u8; 12]> = string_coordinate_index_(a.view(), b.view(), 1).unwrap(); + + let 
a_b_string: Vec<[u8; 12]> = vec![ + [48, 95, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [49, 95, 50, 48, 0, 0, 0, 0, 0, 0, 0, 0], + [57, 57, 57, 57, 57, 95, 57, 57, 57, 57, 57, 0], + ]; + assert_eq!(a_b, Array1::from_vec(a_b_string)); + } + + #[test] + fn test_categorical_coordinate() { + let a = array![0, 0, 1, 0, 1]; + let b = array![0, 1, 0, 0, 1]; + + let (codes, coord): (Array1, Array2) = + categorical_coordinate_(a.view(), b.view(), 1).unwrap(); + + let codes_test = array![0, 1, 2, 0, 3]; + let coord_test = array![[0, 0], [0, 1], [1, 0], [1, 1]]; + + assert_eq!(codes, codes_test); + assert_eq!(coord, coord_test); + } +} diff --git a/src/cosine.rs b/src/cosine.rs new file mode 100644 index 0000000..d1e3e01 --- /dev/null +++ b/src/cosine.rs @@ -0,0 +1,304 @@ +use crate::gridcounts::GridCounts; +use crate::sparsekde::sparse_kde_csx_; +use crate::utils::create_pool; +use ndarray::{ + concatenate, s, Array, Array2, Array3, ArrayView2, Axis, NdFloat, NewAxis, ShapeError, Slice, + Zip, +}; +use ndarray_stats::QuantileExt; +use num::{one, zero, NumCast, PrimInt, Signed}; +use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2}; +use pyo3::{exceptions::PyValueError, prelude::*}; +use rayon::prelude::*; +use sprs::{CompressedStorage::CSR, CsMatI, CsMatViewI, SpIndex}; +use std::{ + cmp::{max, min}, + error::Error, + ops::Range, +}; + +macro_rules! 
build_cos_ct_fn { + ($name:tt, $t_cos:ty, $t_ct:ty) => { + #[pyfunction] + #[pyo3(signature = (counts, genes, signatures, kernel, *, log=false, chunk_size=(500, 500), n_threads=None))] + /// calculate cosine similarity and assign celltype + pub fn $name<'py>( + py: Python<'py>, + counts: &mut GridCounts, + genes: Vec, + signatures: PyReadonlyArray2<'py, $t_cos>, + kernel: PyReadonlyArray2<'py, $t_cos>, + log: bool, + chunk_size: (usize, usize), + n_threads: Option, + ) -> PyResult<(Bound<'py, PyArray2<$t_cos>>, Bound<'py, PyArray2<$t_ct>>)> { + + // ensure that all count arrays are CSR + counts.to_format(CSR); + let gene_counts: Vec<_> = genes + .iter() + .map(|g| { + counts + .get_view(g) + .ok_or(PyValueError::new_err("Not all genes exist")) + }) + .collect::>()?; + + let cos_ct = chunk_and_calculate_cosine( + &gene_counts, + signatures.as_array(), + kernel.as_array(), + counts.shape, + log, + chunk_size, + n_threads.unwrap_or(0), + ); + + match cos_ct { + Ok((cosine, celltype_map)) => Ok(( + cosine.into_pyarray_bound(py), + celltype_map.into_pyarray_bound(py), + )), + Err(e) => Err(PyValueError::new_err(e.to_string())), + } + } + }; +} + +build_cos_ct_fn!(cosinef32_and_celltypei8, f32, i8); +build_cos_ct_fn!(cosinef32_and_celltypei16, f32, i16); + +fn chunk_and_calculate_cosine<'a, C, I, F, U>( + counts: &[CsMatViewI<'a, C, I>], + signatures: ArrayView2<'a, F>, + kernel: ArrayView2<'a, F>, + shape: (usize, usize), + log: bool, + chunk_size: (usize, usize), + n_threads: usize, +) -> Result<(Array2, Array2), Box> +where + C: NumCast + Copy + Sync + Send + Default, + I: SpIndex + Signed + Sync + Send, + F: NdFloat, + U: PrimInt + Signed + Sync + Send, + Slice: From>, +{ + let pool = create_pool(n_threads)?; + + let kernelsize = kernel.shape(); + + let (nrow, ncol) = shape; + let (srow, scol) = chunk_size; + let (padrow, padcol) = ((kernelsize[0] - 1) / 2, (kernelsize[1] - 1) / 2); + let (m, n) = (nrow.div_ceil(srow), ncol.div_ceil(scol)); // number of chunks + + 
let mut cosine_rows = Vec::with_capacity(m); + let mut celltype_rows = Vec::with_capacity(m); + + pool.install(|| { + for i in 0..m { + let (slice_row, unpad_row) = chunk_(i, srow, nrow, padrow); + let row_chunk: Vec<_> = counts + .par_iter() + .map(|c| { + c.slice_outer(slice_row.clone()) + .transpose_view() + .to_other_storage() + }) + .collect(); + + let (cosine_cols, celltype_cols): (Vec>, Vec>) = (0..n) + .into_par_iter() + .map(|j| { + let (slice_col, unpad_col) = chunk_(j, scol, ncol, padcol); + + let chunk = row_chunk + .par_iter() + .map(|c| c.slice_outer(slice_col.clone()).transpose_into().to_owned()) + .collect(); + + cosine_and_celltype_( + chunk, + signatures, + kernel, + (unpad_row.clone(), unpad_col), + log, + ) + }) + .unzip(); + cosine_rows.push(concat_1d(cosine_cols, 1)); + celltype_rows.push(concat_1d(celltype_cols, 1)); + } + }); + + let cosine = concat_1d(cosine_rows.into_iter().collect::, _>>()?, 0)?; + let celltype = concat_1d(celltype_rows.into_iter().collect::, _>>()?, 0)?; + Ok((cosine, celltype)) +} + +fn concat_1d( + chunks: Vec>, + axis: usize, +) -> Result, ShapeError> { + concatenate( + Axis(axis), + &chunks.par_iter().map(|a| a.view()).collect::>(), + ) +} + +fn chunk_(i: usize, step: usize, n: usize, pad: usize) -> (Range, Range) { + let bound1 = (i * step) as isize; + let bound2 = (i + 1) * step; + let start = max(0, bound1 - pad as isize) as usize; + let start2 = max(0, bound1 - start as isize) as usize; + let chunk_pad = start..min(n, bound2 + pad); + let chunk_unpad = start2..(start2 + min(step, (n as isize - bound1) as usize)); + (chunk_pad, chunk_unpad) +} + +fn cosine_and_celltype_<'a, C, I, F, U>( + counts: Vec>, + signatures: ArrayView2<'a, F>, + kernel: ArrayView2<'a, F>, + unpad: (Range, Range), + log: bool, +) -> (Array2, Array2) +where + C: NumCast + Copy, + F: NdFloat, + U: PrimInt + Signed, + I: SpIndex + Signed, + Slice: From>, +{ + let (unpad_r, unpad_c) = unpad; + let mut csx_weights_iter = counts + .into_iter() + 
.zip(signatures.rows()) + .filter(|(csx, _)| csx.nnz() > 0); + + match csx_weights_iter.next() { + // fastpath if all csx are empty + None => { + let shape = (unpad_r.end - unpad_r.start, unpad_c.end - unpad_c.start); + (Array2::zeros(shape), Array2::from_elem(shape, -one::())) + } + Some((csx, weights)) => { + let shape = csx.shape(); + let mut kde = Array2::zeros(shape); + + sparse_kde_csx_(&mut kde, csx.view(), kernel); + if log { + kde.mapv_inplace(F::ln_1p); + } + + let mut kde_norm = kde + .slice(s![unpad_r.clone(), unpad_c.clone()]) + .map(|k| k.powi(2)); + let mut cosine: Array3 = &kde.slice(s![NewAxis, unpad_r.clone(), unpad_c.clone()]) + * &weights.slice(s![.., NewAxis, NewAxis]); + + for (csx, weights) in csx_weights_iter { + sparse_kde_csx_(&mut kde, csx.view(), kernel); + let mut kde_unpadded = kde.slice_mut(s![unpad_r.clone(), unpad_c.clone()]); + if log { + kde_unpadded.mapv_inplace(F::ln_1p); + } + + Zip::from(&mut kde_norm) + .and(&kde_unpadded) + .for_each(|n, &k| *n += k.powi(2)); + + cosine + .outer_iter_mut() + .zip(&weights) + .for_each(|(mut cos, &w)| cos += &kde_unpadded.map(|&x| x * w)); + } + // TODO: write to zarr + get_max_cosine_and_celltype(cosine, kde_norm) + } + } +} + +fn get_max_cosine_and_celltype( + cosine: Array3, + kde_norm: Array2, +) -> (Array2, Array2) +where + I: PrimInt + Signed, + F: NdFloat, +{ + let (mut max_cosine, mut celltypemap) = get_max_argmax(&cosine); + + Zip::from(&mut celltypemap) + .and(&mut max_cosine) + .and(&kde_norm) + .for_each(|ct, cos, &norm| { + if norm == zero() { + *ct = -one::(); + } else { + *cos /= norm.sqrt(); + } + }); + + (max_cosine, celltypemap) +} + +pub fn get_max_argmax( + array: &Array3, +) -> (Array2, Array2) { + let argmax = array.map_axis(Axis(0), |view| view.argmax().unwrap()); + let max = Array::from_shape_fn(argmax.raw_dim(), |(i, j)| array[[argmax[[i, j]], i, j]]); + // let max = array.map_axis(Axis(0), |view| view.max().unwrap().clone()); + (max, argmax.mapv(|i| 
I::from(i).unwrap())) +} + +#[cfg(test)] +mod tests { + + use super::*; + use ndarray::array; + + struct Setup { + cosine: Array3, + norm: Array2, + max: Array2, + argmax: Array2, + cos: Array2, + celltype: Array2, + } + + impl Setup { + fn new() -> Self { + Self { + cosine: array![[[1.0, 0.0, 0.0]], [[0.5, 1.0, 0.0]]], + norm: array![[4.0, 1.0, 0.0]], + max: array![[1.0, 1.0, 0.0]], + argmax: array![[0, 1, 0]], + cos: array![[0.5, 1.0, 0.0]], + celltype: array![[0, 1, -1]], + } + } + } + + #[test] + fn test_max_argmax() { + let setup = Setup::new(); + + let max_argmax: (Array2, Array2) = get_max_argmax(&setup.cosine); + + assert_eq!(max_argmax.0, setup.max); + assert_eq!(max_argmax.1, setup.argmax); + } + + #[test] + fn test_get_max_cosine_and_celltype() { + let setup = Setup::new(); + + let cos_ct: (Array2, Array2) = + get_max_cosine_and_celltype(setup.cosine, setup.norm); + + assert_eq!(cos_ct.0, setup.cos); + assert_eq!(cos_ct.1, setup.celltype); + } +} diff --git a/src/gridcounts.rs b/src/gridcounts.rs new file mode 100644 index 0000000..c1f64d3 --- /dev/null +++ b/src/gridcounts.rs @@ -0,0 +1,426 @@ +use crate::sparsearray_conversion::WrappedCsx; +use crate::utils::create_pool; +use bincode::{deserialize, serialize}; +use itertools::MultiUnzip; +use ndarray::Array2; +use num::Zero; +use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2}; +use polars::{ + datatypes::DataType::{Int32, UInt32}, + prelude::*, +}; +use pyo3::{ + exceptions::{PyKeyError, PyRuntimeError, PyValueError}, + prelude::*, + types::{PyBytes, PyType}, +}; +use pyo3_polars::PyDataFrame; +use rayon::{ + iter::{ + IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, + }, + join, ThreadPool, +}; +use sprs::{ + CompressedStorage::{self, CSC, CSR}, + CsMatI, CsMatViewI, SpIndex, TriMatI, +}; +use std::{ + cmp::min, + collections::{HashMap, HashSet}, + ops::AddAssign, +}; + +/// Class implementation + +pub type Count = u32; +pub type CsxIndex = i32; + 
+#[pyclass(mapping, module = "sainsc")] +pub struct GridCounts { + counts: HashMap>, + #[pyo3(get)] + pub shape: (usize, usize), + #[pyo3(get)] + pub resolution: Option, + #[pyo3(get)] + pub n_threads: usize, + threadpool: ThreadPool, +} + +impl GridCounts { + pub fn get_view(&self, gene: &String) -> Option> { + self.counts.get(gene).map(|x| x.view()) + } + + pub fn to_format(&mut self, format: CompressedStorage) { + self.threadpool.install(|| { + self.counts.par_iter_mut().for_each(|(_, v)| { + if format != v.storage() { + *v = v.to_other_storage() + } + }) + }); + } +} + +#[pymethods] +impl GridCounts { + #[new] + #[pyo3(signature = (counts, *, resolution=None, n_threads=None))] + fn new( + counts: HashMap>, + resolution: Option, + n_threads: Option, + ) -> PyResult { + let n_threads = n_threads.unwrap_or(0); + let threadpool = + create_pool(n_threads).map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + let counts: HashMap<_, _> = + threadpool.install(|| counts.into_par_iter().map(|(k, v)| (k, v.0)).collect()); + + let shape = if counts.is_empty() { + (0, 0) + } else { + let shapes: Vec<_> = counts.values().map(|v| v.shape()).collect(); + + if !shapes.windows(2).all(|w| w[0] == w[1]) { + return Err(PyValueError::new_err( + "All sparse arrays must have same shape", + )); + } + + *shapes.first().expect("Length is non-zero") + }; + + Ok(Self { + counts, + shape, + resolution, + n_threads, + threadpool, + }) + } + + #[classmethod] + #[pyo3(signature = (df, *, resolution=None, binsize=None, n_threads=None))] + fn from_dataframe( + _cls: &Bound<'_, PyType>, + df: PyDataFrame, + resolution: Option, + binsize: Option, + n_threads: Option, + ) -> PyResult { + fn _from_dataframe( + mut df: DataFrame, + binsize: Option, + ) -> Result< + ( + HashMap>, + (usize, usize), + ), + PolarsError, + > { + fn col_as_nonull_vec( + df: &DataFrame, + col: &str, + f: F, + ) -> Result::Native>, PolarsError> + where + F: Fn(&Series) -> Result<&ChunkedArray, PolarsError>, + T: 
PolarsNumericType, + { + Ok(f(df.column(col)?)? + .to_vec_null_aware() + .expect_left(&format!("{col} should have no null"))) + } + + // bin if binsize is provided + if let Some(bin) = binsize { + df.with_column(df.column("x")? / bin)?; + df.with_column(df.column("y")? / bin)?; + } + + // cast to correct dtypes and shift (i.e. subtract min) + for col in ["x", "y"] { + let s = df.column(col)?.strict_cast(&Int32)?; + df.with_column(&s - s.min::()?.expect("non-null"))?; + } + + match df.column("count") { + // if counts does not exist use all 1s + Err(_) => df.with_column(Series::new("count", vec![1u32; df.height()]))?, + Ok(s) => df.with_column(s.strict_cast(&UInt32)?)?, + }; + + // if df.column("gene")?.unpack()?.dtype() != Categorical { + // df.with_column(df.column("gene")?.cast(&Categorical)?); + // } + + let shape = ( + df.column("x")?.max::()?.expect("non-null") + 1, + df.column("y")?.max::()?.expect("non-null") + 1, + ); + + let counts_dict = df + .partition_by(["gene"], true)? + .into_par_iter() + .map(|df| { + let gene = df + .column("gene")? + .categorical()? 
+ .iter_str() + .next() + .expect("df must be non-empty") + .expect("`gene` must not be null") + .to_owned(); + + // let gene = match df.column("gene")?.unpack()?.get(0) { + // s: String => s.clone(), + // CategoricalType(s) => s.clone(), + // }; + + let x = col_as_nonull_vec(&df, "x", |s| s.i32())?; + let y = col_as_nonull_vec(&df, "y", |s| s.i32())?; + let counts = col_as_nonull_vec(&df, "count", |s| s.u32())?; + + Ok::<_, PolarsError>(( + gene, + TriMatI::from_triplets(shape, x, y, counts).to_csr::(), + )) + }) + .collect::>()?; + + Ok((counts_dict, shape)) + } + + let n_threads = n_threads.unwrap_or(0); + let threadpool = + create_pool(n_threads).map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + let resolution = resolution.map(|r| r * binsize.unwrap_or(1.)); + + match threadpool.install(|| _from_dataframe(df.into(), binsize)) { + Err(e) => Err(PyValueError::new_err(e.to_string())), + Ok((counts, shape)) => Ok(Self { + counts, + shape, + resolution, + n_threads, + threadpool, + }), + } + } + + fn __getitem__(&self, key: String) -> PyResult> { + match self.counts.get(&key) { + None => Err(PyKeyError::new_err(format!("'{key}' does not exist."))), + Some(mat) => Ok(WrappedCsx(mat.clone())), + } + } + + fn __setitem__(&mut self, key: String, value: WrappedCsx) { + self.counts.insert(key, value.0); + } + + fn __delitem__(&mut self, key: String) -> PyResult<()> { + match self.counts.remove(&key) { + None => Err(PyKeyError::new_err(key.to_string())), + Some(_) => Ok(()), + } + } + + fn __len__(&self) -> usize { + self.counts.len() + } + + fn __contains__(&self, item: String) -> bool { + self.counts.contains_key(&item) + } + + fn __eq__(&self, other: &GridCounts) -> bool { + (self.resolution == other.resolution) + && (self.shape == other.shape) + && self.counts.eq(&other.counts) + } + + fn __ne__(&self, other: &GridCounts) -> bool { + !self.__eq__(other) + } + + fn __setstate__<'py>(&mut self, state: Bound<'py, PyBytes>) -> PyResult<()> { + let (counts, shape, 
resolution, n_threads) = deserialize(state.as_bytes()).unwrap(); + self.counts = counts; + self.shape = shape; + self.resolution = resolution; + self.set_n_threads(n_threads); + + Ok(()) + } + + fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult> { + let to_bytes = (&self.counts, self.shape, self.resolution, self.n_threads); + + Ok(PyBytes::new_bound(py, &serialize(&to_bytes).unwrap())) + } + + fn __getnewargs_ex__( + &self, + ) -> PyResult<( + (HashMap>,), + HashMap, + )> { + Ok(((HashMap::new(),), HashMap::new())) + } + + // fn __iter__(&self) -> PyResult> { + // Python::with_gil(|py| PyList::new_bound(py, self.counts.keys()).iter()) + // } + // fn keys(&self) -> PyResult> { + // self.__iter__() + // } + + fn get( + &self, + key: String, + default: Option>, + ) -> Option> { + match self.__getitem__(key) { + Ok(x) => Some(x), + Err(_) => default, + } + } + + #[setter] + fn set_resolution(&mut self, resolution: f32) -> PyResult<()> { + if resolution > 0. { + self.resolution = Some(resolution); + Ok(()) + } else { + Err(PyValueError::new_err( + "`resolution` must be greater than zero", + )) + } + } + + #[setter] + fn set_n_threads(&mut self, n_threads: usize) -> PyResult<()> { + self.n_threads = n_threads; + self.threadpool = + create_pool(n_threads).map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(()) + } + + fn genes(&self) -> Vec { + self.counts.keys().cloned().collect() + } + + fn gene_counts(&self) -> HashMap { + self.threadpool.install(|| { + self.counts + .par_iter() + .map(|(gene, mat)| (gene.to_owned(), mat.data().iter().sum())) + .collect() + }) + } + + fn grid_counts(&self) -> Py> { + fn triplet_to_dense, I: SpIndex>( + coo: TriMatI, + ) -> Array2 { + let mut dense: Array2 = Array2::zeros(coo.shape()); + coo.triplet_iter().for_each(|(v, (i, j))| { + dense[[ + i.to_usize().expect("valid index"), + j.to_usize().expect("valid index"), + ]] += *v; + }); + + dense + } + let (v, (i, j)) = self.threadpool.install(|| { + let (v, (i, j)): (Vec<_>, 
(Vec<_>, Vec<_>)) = self + .counts + .par_iter() + .map(|(_, mat)| -> (Vec<_>, (Vec<_>, Vec<_>)) { mat.iter().multiunzip() }) + .unzip(); + + join(|| v.concat(), || join(|| i.concat(), || j.concat())) + }); + + let gridcounts = triplet_to_dense(TriMatI::from_triplets(self.shape, i, j, v)); + Python::with_gil(|py| gridcounts.into_pyarray_bound(py).unbind()) + } + + fn select_genes(&mut self, genes: HashSet) { + self.counts.retain(|k, _| genes.contains(k)); + } + + #[pyo3(signature = (min=1, max=Count::MAX))] + fn filter_genes_by_count(&mut self, min: Count, max: Count) { + let genes: HashSet<_> = self.threadpool.install(|| { + self.gene_counts() + .into_par_iter() + .filter_map(|(gene, count)| { + if (count >= min) & (count <= max) { + Some(gene) + } else { + None + } + }) + .collect() + }); + self.select_genes(genes); + } + + fn crop( + &mut self, + x: (Option, Option), + y: (Option, Option), + ) -> PyResult<()> { + let x_start = x.0.unwrap_or(0); + let y_start = y.0.unwrap_or(0); + let x_end = x.1.map_or(self.shape.0, |x| min(x, self.shape.0)); + let y_end = y.1.map_or(self.shape.1, |x| min(x, self.shape.1)); + + if (x_end <= x_start) || (y_end <= y_start) { + return Err(PyValueError::new_err("Trying to crop with empty slice.")); + } + + self.threadpool.install(|| { + self.counts.par_iter_mut().for_each(|(_, mat)| { + let (outer, inner) = match mat.storage() { + CSR => (x_start..x_end, y_start..y_end), + CSC => (y_start..y_end, x_start..x_end), + }; + *mat = mat + .slice_outer(outer) + .transpose_into() + .to_other_storage() + .slice_outer(inner) + .transpose_into() + .to_owned(); + }); + }); + + self.shape = (x_end - x_start, y_end - y_start); + Ok(()) + } + + fn filter_mask(&mut self, mask: PyReadonlyArray2<'_, bool>) { + let mask = mask.as_array(); + + self.threadpool.install(|| { + self.counts.par_iter_mut().for_each(|(_, mat)| { + let (data, x, y) = mat + .into_iter() + .filter(|(_, (i, j))| mask[[*i as usize, *j as usize]]) + .map(|(v, (i, j))| (v, i, j)) + 
.multiunzip(); + + *mat = TriMatI::from_triplets(self.shape, x, y, data).to_csr(); + }); + }); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a837473 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,26 @@ +mod coordinates; +mod cosine; +mod gridcounts; +mod sparsearray_conversion; +mod sparsekde; +mod utils; + +use coordinates::{categorical_coordinate, coordinate_as_string}; +use cosine::{cosinef32_and_celltypei16, cosinef32_and_celltypei8}; +use gridcounts::GridCounts; +use pyo3::prelude::*; +use sparsekde::{kde_at_coord, sparse_kde_csx_py}; + +/// utils written in Rust to improve performance +#[pymodule] +// #[pyo3(name = "_utils_rust")] +fn _utils_rust(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_function(wrap_pyfunction!(sparse_kde_csx_py, m)?)?; + m.add_function(wrap_pyfunction!(kde_at_coord, m)?)?; + m.add_function(wrap_pyfunction!(cosinef32_and_celltypei8, m)?)?; + m.add_function(wrap_pyfunction!(cosinef32_and_celltypei16, m)?)?; + m.add_function(wrap_pyfunction!(coordinate_as_string, m)?)?; + m.add_function(wrap_pyfunction!(categorical_coordinate, m)?)?; + Ok(()) +} diff --git a/src/sparsearray_conversion.rs b/src/sparsearray_conversion.rs new file mode 100644 index 0000000..8a94005 --- /dev/null +++ b/src/sparsearray_conversion.rs @@ -0,0 +1,117 @@ +use numpy::{Element, IntoPyArray, PyArray1, PyReadonlyArray1}; +use pyo3::{ + exceptions::{PyTypeError, PyValueError}, + prelude::*, + sync::GILOnceCell, +}; +use sprs::{ + CompressedStorage::{CSC, CSR}, + CsMatBase, CsMatI, SpIndex, +}; + +// cache scipy imports +static SP_SPARSE: GILOnceCell> = GILOnceCell::new(); +static CSR_ARRAY: GILOnceCell> = GILOnceCell::new(); +static CSC_ARRAY: GILOnceCell> = GILOnceCell::new(); +static SPARRAY: GILOnceCell> = GILOnceCell::new(); +static SPMATRIX: GILOnceCell> = GILOnceCell::new(); + +// implement WrappedCsxView? 
+ +/// Conversion type for sprs::CsMat <-> scipy.sparse.csx_array +pub struct WrappedCsx(pub CsMatI); + +fn get_scipy_sparse(py: Python) -> PyResult<&Py> { + SP_SPARSE.get_or_try_init(py, || Ok(py.import_bound("scipy.sparse")?.unbind())) +} + +fn get_scipy_sparse_attr(py: Python, attr: &str) -> PyResult { + get_scipy_sparse(py)?.getattr(py, attr) +} + +/// Return a CsMat in SciPy CSX tuple order +pub fn make_csx_tuple( + py: Python<'_>, + cs: CsMatI, +) -> ( + Bound<'_, PyArray1>, + Bound<'_, PyArray1>, + Bound<'_, PyArray1>, +) +where + D: Element, + I: Element + SpIndex, + Iptr: Element + SpIndex, +{ + let (indptr, indices, data) = cs.into_raw_storage(); + + return ( + data.into_pyarray_bound(py), + indices.into_pyarray_bound(py), + indptr.into_pyarray_bound(py), + ); +} + +impl IntoPy + for WrappedCsx +{ + fn into_py(self, py: Python<'_>) -> PyObject { + let csx = self.0; + let shape = csx.shape(); + + let sparray = match csx.storage() { + CSR => CSR_ARRAY.get_or_try_init(py, || get_scipy_sparse_attr(py, "csr_array")), + CSC => CSC_ARRAY.get_or_try_init(py, || get_scipy_sparse_attr(py, "csc_array")), + }; + + sparray + .unwrap() + .call1(py, (make_csx_tuple(py, csx), shape)) + .unwrap() + .extract(py) + .unwrap() + } +} +impl<'py, N: Element, I: SpIndex + Element, Iptr: SpIndex + Element> FromPyObject<'py> + for WrappedCsx +{ + fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult { + fn boundpyarray_to_vec(obj: Bound<'_, PyAny>) -> PyResult> { + Ok(obj.extract::>()?.as_array().to_vec()) + } + + Python::with_gil(|py| { + let sparray = SPARRAY + .get_or_try_init(py, || get_scipy_sparse_attr(py, "sparray"))? + .bind(py); + let spmatrix = SPMATRIX + .get_or_try_init(py, || get_scipy_sparse_attr(py, "spmatrix"))? + .bind(py); + + let format = obj.getattr("format")?; + if !((obj.is_instance(spmatrix)? || obj.is_instance(sparray)?) + && ((format.eq("csr")?) 
|| (format.eq("csc")?))) + { + Err(PyTypeError::new_err( + "Only `sparray`/`spmatrix` with format 'csr' or 'csc' can be extracted.", + )) + } else { + let shape = obj.getattr("shape")?.extract()?; + + let data = boundpyarray_to_vec(obj.getattr("data")?)?; + let indices = boundpyarray_to_vec(obj.getattr("indices")?)?; + let indptr = boundpyarray_to_vec(obj.getattr("indptr")?)?; + + let csx = if format.eq("csr")? { + CsMatBase::new_from_unsorted(shape, indptr, indices, data) + } else { + CsMatBase::new_from_unsorted_csc(shape, indptr, indices, data) + }; + match csx { + Ok(csx) => Ok(WrappedCsx(csx)), + Err((.., e)) => Err(PyValueError::new_err(e.to_string())), + } + } + }) + } +} diff --git a/src/sparsekde.rs b/src/sparsekde.rs new file mode 100644 index 0000000..d69550a --- /dev/null +++ b/src/sparsekde.rs @@ -0,0 +1,300 @@ +use crate::{ + gridcounts::{Count, CsxIndex, GridCounts}, + sparsearray_conversion::WrappedCsx, + utils::create_pool, +}; +use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, NdFloat, NewAxis, Slice, Zip}; +use num::{one, zero, NumCast, PrimInt, Signed, Zero}; +use numpy::{PyReadonlyArray1, PyReadonlyArray2}; +use pyo3::exceptions::PyValueError; +use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use rayon::{iter::ParallelIterator, prelude::ParallelSlice, ThreadPoolBuildError}; +use sprs::{hstack, CsMatI, CsMatViewI, SpIndex}; +use std::{ + cmp::{max, min}, + ops::Range, +}; + +type KDEPrecision = f32; + +macro_rules! 
build_kde_csx_fn { + ($name:tt, $t_count:ty, $t_index:ty) => { + #[pyfunction] + #[pyo3(signature = (counts, kernel, *, threshold=0.0))] + /// calculate sparse KDE + pub fn $name<'py>( + _py: Python<'py>, + counts: WrappedCsx<$t_count, $t_index, $t_index>, + kernel: PyReadonlyArray2<'py, KDEPrecision>, + threshold: KDEPrecision, + ) -> PyResult> { + let sparse_kde = sparse_kde_csx(counts.0.view(), kernel.as_array(), threshold); + Ok(WrappedCsx(sparse_kde)) + } + }; +} + +build_kde_csx_fn!(sparse_kde_csx_py, Count, CsxIndex); + +#[pyfunction] +#[pyo3(signature = (counts, genes, kernel, coordinates, *, n_threads=None))] +/// calculate KDE and retrieve coordinates +pub fn kde_at_coord<'py>( + _py: Python<'py>, + counts: &GridCounts, + genes: Vec, + kernel: PyReadonlyArray2<'py, KDEPrecision>, + coordinates: (PyReadonlyArray1<'py, isize>, PyReadonlyArray1<'py, isize>), + n_threads: Option, +) -> PyResult> { + let gene_counts: Vec<_> = genes + .iter() + .map(|g| { + counts + .get_view(g) + .ok_or(PyValueError::new_err("Not all genes exist")) + }) + .collect::>()?; + + let coordinates = (coordinates.0.as_array(), coordinates.1.as_array()); + + match kde_at_coord_( + &gene_counts, + kernel.as_array(), + coordinates, + n_threads.unwrap_or(0), + ) { + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + Ok(kde_coord) => Ok(WrappedCsx(kde_coord)), + } +} + +#[inline] +fn in_bounds_range(n: I, i: I, pad: I) -> (I, I) { + (max(-i, -pad), min(n - i, pad + one())) +} + +fn sparse_kde_csx<'a, C, I, I2, F>( + counts: CsMatViewI<'a, C, I, I2>, + kernel: ArrayView2<'a, F>, + threshold: F, +) -> CsMatI +where + C: NumCast + Copy, + I: SpIndex + Signed, + I2: SpIndex, + F: NdFloat + Signed, + Slice: From>, +{ + let mut kde = Array2::zeros(counts.shape()); + sparse_kde_csx_(&mut kde, counts, kernel); + CsMatI::csr_from_dense(ArrayView2::from(&kde), threshold) +} + +pub fn sparse_kde_csx_<'a, 'b, C, I, I2, F>( + kde: &mut Array2, + counts: CsMatViewI<'a, C, I, I2>, + kernel: 
ArrayView2<'b, F>, +) where + C: NumCast + Copy, + I: SpIndex + Signed, + I2: SpIndex, + F: NdFloat, + Slice: From>, + 'b: 'a, +{ + let shape = kde.shape(); + let (m, n) = (I::from(shape[0]).unwrap(), I::from(shape[1]).unwrap()); + + let shift_i = I::from((kernel.nrows() - 1) / 2).unwrap(); + let shift_j = I::from((kernel.ncols() - 1) / 2).unwrap(); + + kde.fill(zero()); + + counts.iter().for_each(|(&val, (i, j))| { + let (i_min, i_max) = in_bounds_range(m, i, shift_i); + let (j_min, j_max) = in_bounds_range(n, j, shift_j); + let val = F::from(val).unwrap(); + + Zip::from(&mut kde.slice_mut(s![ + Slice::from((i + i_min)..(i + i_max)), + Slice::from((j + j_min)..(j + j_max)) + ])) + .and(&kernel.slice(s![ + Slice::from((shift_i + i_min)..(shift_i + i_max)), + Slice::from((shift_j + j_min)..(shift_j + j_max)) + ])) + .for_each(|kde, &k| { + *kde += k * val; + }); + }) +} + +fn kde_at_coord_<'a, C, I, F, I2>( + counts: &[CsMatViewI<'a, C, I>], + kernel: ArrayView2<'a, F>, + coordinates: (ArrayView1<'a, I2>, ArrayView1<'a, I2>), + n_threads: usize, +) -> Result, ThreadPoolBuildError> +where + C: NumCast + Copy + Sync + Send, + I: SpIndex + Signed + Sync + Send, + I2: PrimInt, + F: NdFloat + Signed + Default, + Slice: From>, +{ + let pool = create_pool(n_threads)?; + let shape = counts.first().expect("At least one gene").shape(); + + let coord_x = coordinates + .0 + .mapv(|x| ::from(x).unwrap()) + .to_vec(); + let coord_y = coordinates + .1 + .mapv(|x| ::from(x).unwrap()) + .to_vec(); + + let batch_size = counts.len().div_ceil(n_threads); + + let kde_coords = pool.install(|| { + counts + .par_chunks(batch_size) + .map(|counts_batch| { + let mut kde_buffer = Array2::zeros(shape); + let mut kde_coords_batch = Vec::with_capacity(counts_batch.len()); + for &c in counts_batch { + sparse_kde_csx_(&mut kde_buffer, c, kernel); + kde_coords_batch.push(get_coord(kde_buffer.view(), (&coord_x, &coord_y))); + } + kde_coords_batch + }) + .flatten_iter() + .collect::>() + }); + + 
Ok(hstack( + &kde_coords.iter().map(|x| x.view()).collect::>(), + )) +} + +fn get_coord( + arr: ArrayView2<'_, T>, + coordinates: (&[usize], &[usize]), +) -> CsMatI { + let mut out = Array1::zeros(coordinates.0.len()); + + Zip::from(&mut out) + .and(coordinates.0) + .and(coordinates.1) + .for_each(|o, &i, &j| *o = arr[[i, j]]); + CsMatI::csc_from_dense(out.slice(s![.., NewAxis]), zero()) +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::array; + + struct Setup { + counts: CsMatI, + kernel: Array2, + threshold: f64, + kde: CsMatI, + } + + impl Setup { + fn new() -> Self { + Self { + counts: CsMatI::new( + (6, 12), + vec![0, 1, 2, 3, 3, 3, 4], + vec![0, 2, 3, 10], + vec![1, 1, 2, 1], + ), + kernel: array![[0.5, 0.0, 0.0], [0.5, 1.0, 0.0], [0.25, 0.0, 0.0]], + threshold: 0.4, + kde: CsMatI::new( + (6, 12), + vec![0, 2, 4, 6, 7, 8, 10], + vec![0, 1, 1, 2, 2, 3, 2, 9, 9, 10], + vec![1.0, 0.5, 0.5, 2.0, 1.0, 2.0, 0.5, 0.5, 0.5, 1.0], + ), + } + } + } + + // Input looks like this + // [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + // [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + // [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0], + // [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + // [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + // [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]] + + // Unfiltered output should look like this but after thresholding the 0.25 is removed + // [[1.0, 0.5 , 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + // [0.0, 0.5 , 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + // [0.0, 0.25, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + // [0.0, 0.0 , 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + // [0.0, 0.0 , 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0], + // [0.0, 0.0 , 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 0.0]] + + // let kde_test = array![ + // [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + // [0.0, 0.5, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + // [0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + // 
[0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + // [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0], + // [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 0.0] + // ]; + + #[test] + fn test_sparse_kde_csx() { + let setup = Setup::new(); + + let kde = sparse_kde_csx(setup.counts.view(), setup.kernel.view(), setup.threshold); + + assert_eq!(kde, setup.kde); + } + + #[test] + fn test_get_coord() { + let coord_x = vec![0, 1, 1]; + let coord_y = vec![0, 2, 3]; + + let setup = Setup::new(); + + let kde_coord = get_coord(setup.kde.to_dense().view(), (&coord_x, &coord_y)); + + let result = CsMatI::new_csc((3, 1), vec![0, 2], vec![0, 1], vec![1., 2.]); + + assert_eq!(kde_coord, result) + } + #[test] + fn test_kde_at_coord() { + let coord_x = array![0, 1, 1]; + let coord_y = array![0, 2, 3]; + + let setup = Setup::new(); + + let counts = vec![setup.counts.view(), setup.counts.view()]; + + let kde_coord = kde_at_coord_( + &counts, + setup.kernel.view(), + (coord_x.view(), coord_y.view()), + 1, + ) + .unwrap(); + + let result = CsMatI::new_csc( + (3, 2), + vec![0, 2, 4], + vec![0, 1, 0, 1], + vec![1., 2., 1., 2.], + ); + + assert_eq!(kde_coord, result) + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..0847989 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,6 @@ +use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder}; + +/// Create rayon ThreadPool with n threads +pub fn create_pool(n_threads: usize) -> Result { + ThreadPoolBuilder::new().num_threads(n_threads).build() +}