diff --git a/sainsc/_utils_rust.pyi b/sainsc/_utils_rust.pyi index d6fcde0..8a05be7 100644 --- a/sainsc/_utils_rust.pyi +++ b/sainsc/_utils_rust.pyi @@ -46,7 +46,7 @@ def cosinef32_and_celltypei8( log: bool = False, chunk_size: tuple[int, int] = (500, 500), n_threads: int | None = None, -) -> tuple[NDArray[np.float32], NDArray[np.int8]]: +) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.int8]]: """ Calculate the cosine similarity given counts and signatures and assign the most similar celltype. @@ -61,7 +61,7 @@ def cosinef32_and_celltypei16( log: bool = False, chunk_size: tuple[int, int] = (500, 500), n_threads: int | None = None, -) -> tuple[NDArray[np.float32], NDArray[np.int16]]: +) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.int16]]: """ Calculate the cosine similarity given counts and signatures and assign the most similar celltype. @@ -87,7 +87,7 @@ class GridCounts: """ Parameters ---------- - counts : dict[str, scipy.sparse.csr_array | scipy.sparse.csr_matrix| scipy.sparse.csc_array| scipy.sparse.csc_matrix] + counts : dict[str, scipy.sparse.csr_array | scipy.sparse.csr_matrix | scipy.sparse.csc_array | scipy.sparse.csc_matrix] Gene counts. resolution : float, optional Resolution as nm / pixel. diff --git a/sainsc/lazykde/_LazyKDE.py b/sainsc/lazykde/_LazyKDE.py index 9a96319..0458637 100644 --- a/sainsc/lazykde/_LazyKDE.py +++ b/sainsc/lazykde/_LazyKDE.py @@ -1,15 +1,17 @@ from collections.abc import Iterable +from itertools import chain from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt import numpy as np import pandas as pd +import polars as pl import seaborn as sns from anndata import AnnData from matplotlib.axes import Axes from matplotlib.colors import to_rgb from matplotlib.figure import Figure -from matplotlib.lines import Line2D +from matplotlib.patches import Patch from matplotlib_scalebar.scalebar import ScaleBar from mpl_toolkits import axes_grid1 from numba import njit @@ -78,39 +80,100 @@ def __init__( self._threads = n_threads + self._kernel: NDArray[np.float32] | None = None self._total_mRNA: NDArray[np.unsignedinteger] | None = None self._total_mRNA_KDE: NDArray[np.float32] | None = None self._background: NDArray[np.bool_] | None = None self._local_maxima: _Local_Max | None = None self._celltype_map: NDArray[np.signedinteger] | None = None self._cosine_similarity: NDArray[np.float32] | None = None + self._assignment_score: NDArray[np.float32] | None = None self._celltypes: list[str] | None = None + @classmethod + def from_dataframe( + cls, df: pl.DataFrame | pd.DataFrame, *, n_threads: int | None = None, **kwargs + ): + """ + Construct a LazyKDE from a DataFrame. + + The DataFrame must provide a 'gene', 'x', and 'y' column. If a 'count' column + exists it will be used as counts else a count of 1 (single molecule) per row + will be assumed. + + Parameters + ---------- + df : polars.DataFrame | pandas.DataFrame + n_threads : int, optional + Number of threads used for reading and processing file. If `None` this will + default to the number of available CPUs. + kwargs + Other keyword arguments are passed to + :py:meth:`sainsc.GridCounts.from_dataframe`. + """ + if isinstance(df, pd.DataFrame): + df = pl.from_pandas(df) + + # TODO ensure dataframe format + count_col = ["count"] if "count" in df.columns else [] + + df = df.select( + pl.col("gene").cast(pl.Categorical), pl.col(["x", "y"] + count_col) + ) + return cls( + GridCounts.from_dataframe(df, n_threads=n_threads, **kwargs), + n_threads=n_threads, + ) + ## Kernel def gaussian_kernel( - self, sigma: float, *, truncate: float = 2, circular: bool = False + self, + bw: float, + *, + unit: str = "px", + truncate: float = 2, + circular: bool = False, ): """ Set the kernel used for kernel density estimation (KDE) to gaussian. Parameters ---------- - sigma : float + bw : float Bandwidth of the kernel. + unit : str + Which unit the bandwidth of the kernel is defined in: 'px' or 'um'. + 'um' requires :py:attr:`sainsc.LazyKDE.resolution` to be set correctly. truncate : float, optional - The radius for calculating the KDE is calculated as `sigma` * `truncate`. + The radius for calculating the KDE is calculated as `bw` * `truncate`. Refer to :py:func:`scipy.ndimage.gaussian_filter`. circular : bool, optional If `True` calculate the KDE using a circular kernel instead of square by - setting all values outside the radius `sigma` * `truncate` to 0. + setting all values outside the radius `bw` * `truncate` to 0. + + Raises + ------ + ValueError + If `unit` is neither 'px' nor 'um'. + ValueError + If `unit` is 'um' but `resolution` is not set. See Also -------- :py:meth:`sainsc.LazyKDE.kde` """ + + if unit == "um": + if self.resolution is None: + raise ValueError( + "Using `unit`='um' requires the `resolution` to be set." + ) + bw /= self.resolution / 1_000 + elif unit != "px": + raise ValueError("`unit` must be either 'px' or 'um'") dtype = np.float32 - radius = round(truncate * sigma) - self.kernel = gaussian_kernel(sigma, radius, dtype=dtype, circular=circular) + radius = round(truncate * bw) + self.kernel = gaussian_kernel(bw, radius, dtype=dtype, circular=circular) ## KDE def kde(self, gene: str, *, threshold: float | None = None) -> _CsxArray: @@ -143,6 +206,9 @@ def _kde(self, arr: NDArray | _Csx, threshold: float | None = None) -> _CsxArray if isinstance(arr, np.ndarray): arr = csr_array(arr) + if self.kernel is None: + raise ValueError("`kernel` must be set before running KDE") + if arr.dtype == np.uint32: return sparse_kde_csx_py(arr, self.kernel, threshold=threshold) else: @@ -166,6 +232,11 @@ def calculate_total_mRNA_KDE(self): If :py:attr:`sainsc.LazyKDE.total_mRNA` has not been calculated :py:meth:`sainsc.LazyKDE.calculate_total_mRNA` is run first. + Raises + ------ + ValueError + If `self.kernel` is not set. + See Also -------- :py:meth:`sainsc.LazyKDE.gaussian_kernel` @@ -242,6 +313,8 @@ def load_local_maxima( ------ ModuleNotFoundError If `spatialdata` is set to `True` but the package is not installed. + ValueError + If `self.kernel` is not set. """ if self.local_maxima is None: raise ValueError("`local_maxima` have to be identified before loading") @@ -333,6 +406,9 @@ def load_local_maxima( def _load_KDE_maxima(self, genes: list[str]) -> csc_array | csr_array: assert self.local_maxima is not None + if self.kernel is None: + raise ValueError("`kernel` must be set before running KDE") + return kde_at_coord( self.counts, genes, self.kernel, self.local_maxima, n_threads=self.n_threads ) @@ -342,36 +418,47 @@ def filter_background( self, min_norm: float | dict[str, float], min_cosine: float | dict[str, float] | None = None, + min_assignment: float | dict[str, float] | None = None, ): """ - Assign beads as background. + Define pixels as background. + + If using multiple thresholds (e.g. on norm and cosine similarity) they will be + combined and pixels are defined as background if they are lower than any of the + thresholds. Parameters ---------- min_norm : float or dict[str, float] The threshold for defining background based on :py:attr:`sainsc.LazyKDE.total_mRNA_KDE`. - Either a float which is used as global threshold or a mapping from cell-types + Either a float which is used as global threshold or a mapping from cell types to thresholds. Cell-type assignment is needed for cell type-specific thresholds. min_cosine : float or dict[str, float], optional - The threshold for defining background based on the minimum cosine - similarity. Cell type-specific thresholds can be defined as for `min_norm`. + The threshold for defining background based on + :py:attr:`sainsc.LazyKDE.cosine_similarity`. Cell type-specific thresholds + can be defined as for `min_norm`. + min_assignment : float or dict[str, float], optional + The threshold for defining background based on + :py:attr:`sainsc.LazyKDE.assignment_score`. Cell type-specific thresholds + can be defined as for `min_norm`. Raises ------ ValueError - If cell type-specific thresholds do not include all cell-types. + If cell type-specific thresholds do not include all cell types or if + using cell type-specific thresholds before cell type assignment. """ @njit def _map_celltype_to_value( - ct_map: NDArray[np.integer], dict: dict[int, float] + ct_map: NDArray[np.integer], thresholds: tuple[float, ...] ) -> NDArray[np.floating]: values = np.zeros(shape=ct_map.shape, dtype=float) for i in range(ct_map.shape[0]): for j in range(ct_map.shape[1]): if ct_map[i, j] >= 0: - values[i, j] = dict[ct_map[i, j]] + values[i, j] = thresholds[ct_map[i, j]] return values if self.total_mRNA_KDE is None: @@ -386,7 +473,7 @@ def _map_celltype_to_value( ) elif not all([ct in min_norm.keys() for ct in self.celltypes]): raise ValueError("'min_norm' does not contain all celltypes.") - idx2threshold = {idx: min_norm[ct] for idx, ct in enumerate(self.celltypes)} + idx2threshold = tuple(min_norm[ct] for ct in self.celltypes) threshold = _map_celltype_to_value(self.celltype_map, idx2threshold) background = self.total_mRNA_KDE < threshold else: @@ -404,13 +491,29 @@ def _map_celltype_to_value( ) elif not all([ct in min_cosine.keys() for ct in self.celltypes]): raise ValueError("'min_cosine' does not contain all celltypes.") - idx2threshold = { - idx: min_cosine[ct] for idx, ct in enumerate(self.celltypes) - } + idx2threshold = tuple(min_cosine[ct] for ct in self.celltypes) threshold = _map_celltype_to_value(self.celltype_map, idx2threshold) - background &= self.cosine_similarity >= threshold + background |= self.cosine_similarity <= threshold else: - background &= self.cosine_similarity >= min_cosine + background |= self.cosine_similarity <= min_cosine + + if min_assignment is not None: + if self.assignment_score is None: + raise ValueError( + "Assignment score threshold can only be used after cell-type assignment" + ) + if isinstance(min_assignment, dict): + if self.celltypes is None or self.celltype_map is None: + raise ValueError( + "Cell type-specific threshold can only be used after cell-type assignment" + ) + elif not all([ct in min_assignment.keys() for ct in self.celltypes]): + raise ValueError("'min_assignment' does not contain all celltypes.") + idx2threshold = tuple(min_assignment[ct] for ct in self.celltypes) + threshold = _map_celltype_to_value(self.celltype_map, idx2threshold) + background |= self.assignment_score <= threshold + else: + background |= self.assignment_score <= min_assignment self._background = background @@ -446,6 +549,15 @@ def assign_celltype( chunk : tuple[int, int] Size of the chunks for processing. Larger chunks require more memory but have less duplicated computation. + + Raises + ------ + ValueError + If not all genes of the `signatures` are available. + ValueError + If `self.kernel` is not set. + ValueError + If `chunk` is smaller than the shape of `self.kernel`. """ if not all(signatures.index.isin(self.genes)): @@ -453,6 +565,9 @@ def assign_celltype( "Not all genes in the gene signature are part of this KDE." ) + if self.kernel is None: + raise ValueError("`kernel` must be set before running KDE") + if not all(s < c for s, c in zip(self.kernel.shape, chunk)): raise ValueError("`chunk` must be larger than shape of kernel.") @@ -471,7 +586,7 @@ def assign_celltype( fn = self._calculate_cosine_celltype_fn(ct_dtype) - self._cosine_similarity, self._celltype_map = fn( + self._cosine_similarity, self._assignment_score, self._celltype_map = fn( self.counts, genes, signatures_mat, @@ -496,6 +611,7 @@ def _plot_2d( ) -> Figure: if remove_background: if self.background is not None: + img = img.copy() img[self.background] = 0 else: raise ValueError("`background` is undefined") @@ -781,8 +897,11 @@ def plot_celltypemap( crop: _RangeTuple2D | None = None, scalebar: bool = True, cmap: _Cmap = "hls", + background: str | tuple = "black", + undefined: str | tuple = "grey", scalebar_kwargs: dict = _SCALEBAR, - ) -> Figure: + return_img: bool = False, + ) -> Figure | NDArray[np.uint8]: """ Plot the cell-type annotation. @@ -801,15 +920,21 @@ def plot_celltypemap( If it is a list of colors it must have the same length as the number of celltypes. If it is a dictionary it must be a mapping from celltpye to color. Undefined - celltypes are plotted as `'grey'`. + celltypes are plotted according to `undefined`. Colors can either be provided as string that can be converted via :py:func:`matplotlib.colors.to_rgb` or as ``(r, g, b)``-tuple between 0-1. + background : str | tuple[float, float, float] + Color for the background. + undefined : str | tuple[float, float, float] + Color used for celltypes without a defined color. scalebar_kwargs : dict[str, typing.Any], optional Keyword arguments that are passed to ``matplotlib_scalebar.scalebar.ScaleBar``. + return_img : bool, optional + Return the cell-type map as 3D-array (x, y, RGB) instead of the Figure. Returns ------- - matplotlib.figure.Figure + matplotlib.figure.Figure | numpy.ndarray[numpy.uint8] See Also @@ -821,7 +946,7 @@ def plot_celltypemap( n_celltypes = len(self.celltypes) - celltype_map = self.celltype_map + celltype_map = self.celltype_map.copy() if remove_background: if self.background is None: raise ValueError("Background has not been filtered.") @@ -831,6 +956,9 @@ def plot_celltypemap( if crop is not None: celltype_map = celltype_map[tuple(slice(*c) for c in crop)] + # shift so 0 will be background + celltype_map += 1 + if isinstance(cmap, str): color_map = sns.color_palette(cmap, n_colors=n_celltypes) assert isinstance(color_map, Iterable) @@ -840,27 +968,22 @@ def plot_celltypemap( raise ValueError("You need to provide 1 color per celltype") elif isinstance(cmap, dict): - cmap = [cmap.get(cell, "grey") for cell in self.celltypes] + cmap = [cmap.get(cell, undefined) for cell in self.celltypes] color_map = [to_rgb(c) if isinstance(c, str) else c for c in cmap] # convert to uint8 to reduce memory of final image color_map_int = tuple( - (np.array(c) * 255).round().astype(np.uint8) for c in color_map + (np.array(c) * 255).round().astype(np.uint8) + for c in chain([to_rgb(background)], color_map) ) img = _apply_color(celltype_map.T, color_map_int) + if return_img: + return img + legend_elements = [ - Line2D( - [0], - [0], - marker="o", - color="w", - label=lbl, - markerfacecolor=c, - markersize=10, - ) - for c, lbl in zip(color_map, self.celltypes) + Patch(color=c, label=lbl) for c, lbl in zip(color_map, self.celltypes) ] fig, ax = plt.subplots() @@ -928,6 +1051,55 @@ def plot_cosine_similarity( scalebar_kwargs=scalebar_kwargs, ) + def plot_assignment_score( + self, + *, + remove_background: bool = False, + crop: _RangeTuple2D | None = None, + scalebar: bool = True, + im_kwargs: dict = dict(), + scalebar_kwargs: dict = _SCALEBAR, + ) -> Figure: + """ + Plot the assignment score from cell-type assignment. + + Parameters + ---------- + remove_background : bool, optional + If `True`, all pixels for which :py:attr:`sainsc.LazyKDE.background` is + `False` are set to 0. + crop : tuple[tuple[int, int], tuple[int, int]], optional + Coordinates to crop the data defined as `((xmin, xmax), (ymin, ymax))`. + scalebar : bool, optional + If `True`, add a ``matplotlib_scalebar.scalebar.ScaleBar`` to the plot. + im_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to :py:func:`matplotlib.pyplot.imshow`. + scalebar_kwargs : dict[str, typing.Any], optional + Keyword arguments that are passed to ``matplotlib_scalebar.scalebar.ScaleBar``. + + Returns + ------- + matplotlib.figure.Figure + + See Also + -------- + :py:meth:`sainsc.LazyKDE.assign_celltype` + """ + if self.assignment_score is not None: + img = self.assignment_score + else: + raise ValueError("Cell types have not been assigned") + + return self._plot_2d( + img, + "Assignment score", + remove_background=remove_background, + crop=crop, + scalebar=scalebar, + im_kwargs=im_kwargs, + scalebar_kwargs=scalebar_kwargs, + ) + ## Attributes @property def n_threads(self) -> int: @@ -982,7 +1154,7 @@ def resolution(self, resolution: float): self.counts.resolution = resolution @property - def kernel(self) -> np.ndarray: + def kernel(self) -> np.ndarray | None: """ numpy.ndarray: Map of the KDE of total mRNA. @@ -991,7 +1163,7 @@ def kernel(self) -> np.ndarray: ValueError If kernel is not a square, 2D :py:class:`numpy.ndarray` of uneven length. """ - return self._kernel.copy() + return self._kernel @kernel.setter def kernel(self, kernel: np.ndarray): @@ -1028,7 +1200,7 @@ def total_mRNA_KDE(self) -> NDArray[np.single] | None: return self._total_mRNA_KDE @property - def background(self) -> NDArray[np.bool] | None: + def background(self) -> NDArray[np.bool_] | None: """ numpy.ndarray[numpy.bool]: Map of pixels that are assigned as background. @@ -1042,7 +1214,7 @@ def background(self) -> NDArray[np.bool] | None: return self._background @background.setter - def background(self, background: NDArray[np.bool]): + def background(self, background: NDArray[np.bool_]): if background.shape != self.shape: raise ValueError("`background` must have same shape as `self`") else: @@ -1062,6 +1234,18 @@ def cosine_similarity(self) -> NDArray[np.single] | None: """ return self._cosine_similarity + @property + def assignment_score(self) -> NDArray[np.single] | None: + """ + numpy.ndarray[numpy.single]: Assignment score for each pixel. + + Let `x` be the gene expression of a pixel, and `i` and `j` the signatures of the + best and 2nd best scoring cell type, respectively. The assignment score is + calculated as :math:`\\frac{cos(\\theta_{xi}) - cos(\\theta_{xj})}{cos(\\pi/2 - \\theta_{ij})}` + where :math:`\\theta` is the angle between the corresponding vectors. + """ + return self._assignment_score + @property def celltype_map(self) -> NDArray[np.signedinteger] | None: """ @@ -1080,6 +1264,8 @@ def __str__(self) -> str: ] if self.resolution is not None: repr.append(f"resolution: {self.resolution} nm / px") + if self.kernel is not None: + repr.append(f"kernel: {self.kernel.shape}") if self.background is not None: repr.append("background: set") if self.local_maxima is not None: diff --git a/sainsc/lazykde/_kernel.py b/sainsc/lazykde/_kernel.py index c3686a8..28d613e 100644 --- a/sainsc/lazykde/_kernel.py +++ b/sainsc/lazykde/_kernel.py @@ -21,7 +21,7 @@ def _make_circular_kernel(kernel: NDArray[T], radius: int) -> NDArray[T]: def gaussian_kernel( - sigma: float, + bw: float, radius: int, *, dtype: DTypeLike = np.float32, @@ -33,7 +33,7 @@ def gaussian_kernel( Parameters ---------- - sigma : float + bw : float Bandwidth of the Gaussian. radius : int Radius of the kernel. Output size will be :math:`2*radius+1`. @@ -54,7 +54,7 @@ def gaussian_kernel( dirac = signal.unit_impulse((mask_size, mask_size), idx="mid") gaussian_kernel = ndimage.gaussian_filter( - dirac, sigma, output=np.float64, **kwargs + dirac, bw, output=np.float64, **kwargs ).astype(dtype) if circular: @@ -63,7 +63,7 @@ def gaussian_kernel( return gaussian_kernel -def epanechnikov_kernel(sigma: float, *, dtype: DTypeLike = np.float32) -> np.ndarray: +def epanechnikov_kernel(bw: float, *, dtype: DTypeLike = np.float32) -> np.ndarray: """ Generate a 2D Epanechnikov kernel array. @@ -73,7 +73,7 @@ def epanechnikov_kernel(sigma: float, *, dtype: DTypeLike = np.float32) -> np.nd Parameters ---------- - sigma : float + bw : float Bandwidth of the kernel. dtype : numpy.typing.DTypeLike, optional Datatype of the kernel. @@ -85,7 +85,7 @@ def epanechnikov_kernel(sigma: float, *, dtype: DTypeLike = np.float32) -> np.nd # https://doi.org/10.1109/CVPR.2000.854761 # c_d = pi for d=2 - r = math.ceil(sigma) + r = math.ceil(bw) dia = 2 * r - 1 # values at r are zero anyways so the kernel matrix can be smaller # 1/2 * pi^-1 * (d+2) @@ -96,8 +96,8 @@ def epanechnikov_kernel(sigma: float, *, dtype: DTypeLike = np.float32) -> np.nd for j in range(dia): x = i - r + 1 y = j - r + 1 - norm = (x / sigma) ** 2 + (y / sigma) ** 2 + norm = (x / bw) ** 2 + (y / bw) ** 2 if norm < 1: kernel[i, j] = scale * (1 - norm) - return kernel + return kernel / np.sum(kernel) diff --git a/sainsc/lazykde/_utils.py b/sainsc/lazykde/_utils.py index 3a8723f..73531f1 100644 --- a/sainsc/lazykde/_utils.py +++ b/sainsc/lazykde/_utils.py @@ -23,11 +23,10 @@ def _apply_color( img_in: NDArray[np.integer], cmap: tuple[NDArray[T], ...] ) -> NDArray[T]: - img = np.zeros(shape=(*img_in.shape, 3), dtype=cmap[0].dtype) + img = np.empty(shape=(*img_in.shape, 3), dtype=cmap[0].dtype) for i in range(img_in.shape[0]): for j in range(img_in.shape[1]): - if img_in[i, j] >= 0: - img[i, j, :] = cmap[img_in[i, j]] + img[i, j, :] = cmap[img_in[i, j]] return img @@ -78,4 +77,4 @@ def __call__( log: bool = ..., chunk_size: tuple[int, int] = ..., n_threads: int | None = ..., - ) -> tuple[NDArray[np.float32], NDArray[np.signedinteger]]: ... + ) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.signedinteger]]: ... diff --git a/src/cosine.rs b/src/cosine.rs index d1e3e01..87ec9b7 100644 --- a/src/cosine.rs +++ b/src/cosine.rs @@ -1,12 +1,12 @@ use crate::gridcounts::GridCounts; use crate::sparsekde::sparse_kde_csx_; use crate::utils::create_pool; + use ndarray::{ - concatenate, s, Array, Array2, Array3, ArrayView2, Axis, NdFloat, NewAxis, ShapeError, Slice, - Zip, + concatenate, s, Array2, Array3, ArrayView1, ArrayView2, Axis, NdFloat, NewAxis, ShapeError, + Slice, Zip, }; -use ndarray_stats::QuantileExt; -use num::{one, zero, NumCast, PrimInt, Signed}; +use num::{one, zero, NumCast, PrimInt, Signed, Zero}; use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2}; use pyo3::{exceptions::PyValueError, prelude::*}; use rayon::prelude::*; @@ -14,7 +14,7 @@ use sprs::{CompressedStorage::CSR, CsMatI, CsMatViewI, SpIndex}; use std::{ cmp::{max, min}, error::Error, - ops::Range, + ops::{Range, Sub}, }; macro_rules! build_cos_ct_fn { @@ -31,7 +31,7 @@ macro_rules! build_cos_ct_fn { log: bool, chunk_size: (usize, usize), n_threads: Option, - ) -> PyResult<(Bound<'py, PyArray2<$t_cos>>, Bound<'py, PyArray2<$t_ct>>)> { + ) -> PyResult<(Bound<'py, PyArray2<$t_cos>>, Bound<'py, PyArray2<$t_cos>>, Bound<'py, PyArray2<$t_ct>>)> { // ensure that all count arrays are CSR counts.to_format(CSR); @@ -55,8 +55,9 @@ macro_rules! build_cos_ct_fn { ); match cos_ct { - Ok((cosine, celltype_map)) => Ok(( + Ok((cosine, score, celltype_map)) => Ok(( cosine.into_pyarray_bound(py), + score.into_pyarray_bound(py), celltype_map.into_pyarray_bound(py), )), Err(e) => Err(PyValueError::new_err(e.to_string())), @@ -76,7 +77,7 @@ fn chunk_and_calculate_cosine<'a, C, I, F, U>( log: bool, chunk_size: (usize, usize), n_threads: usize, -) -> Result<(Array2, Array2), Box> +) -> Result<(Array2, Array2, Array2), Box> where C: NumCast + Copy + Sync + Send + Default, I: SpIndex + Signed + Sync + Send, @@ -93,7 +94,29 @@ where let (padrow, padcol) = ((kernelsize[0] - 1) / 2, (kernelsize[1] - 1) / 2); let (m, n) = (nrow.div_ceil(srow), ncol.div_ceil(scol)); // number of chunks + let n_celltype = signatures.ncols(); + let signature_similarity_correction = + Array2::from_shape_fn((n_celltype, n_celltype), |(i, j)| { + if i != j { + let sig1 = signatures.index_axis(Axis(1), i); + let sig2 = signatures.index_axis(Axis(1), j); + // technically we want the dot_product of s=(sig1-sig2) with a vector where + // the negative dimensions of this vector are set to zero (x), + // but these will then cancel out anyway so we can simplify to using the + // dot product with itself s . x => x . x + // additional we need to divide by the norm of x + // as the norm is the sqrt of the dot product with itself (which we + // already calculated) divided by its sqrt we end up with + // s . x / norm(x) = x . x / sqrt(x . x) = sqrt(x . x) + let x = (&sig1 - &sig2).mapv(|x| if x <= zero() { zero() } else { x }); + x.dot(&x).sqrt() + } else { + zero() + } + }); + let mut cosine_rows = Vec::with_capacity(m); + let mut score_rows = Vec::with_capacity(m); let mut celltype_rows = Vec::with_capacity(m); pool.install(|| { @@ -108,7 +131,10 @@ where }) .collect(); - let (cosine_cols, celltype_cols): (Vec>, Vec>) = (0..n) + let ((cosine_cols, score_cols), celltype_cols): ( + (Vec>, Vec>), + Vec>, + ) = (0..n) .into_par_iter() .map(|j| { let (slice_col, unpad_col) = chunk_(j, scol, ncol, padcol); @@ -121,6 +147,7 @@ where cosine_and_celltype_( chunk, signatures, + signature_similarity_correction.view(), kernel, (unpad_row.clone(), unpad_col), log, @@ -128,13 +155,15 @@ where }) .unzip(); cosine_rows.push(concat_1d(cosine_cols, 1)); + score_rows.push(concat_1d(score_cols, 1)); celltype_rows.push(concat_1d(celltype_cols, 1)); } }); let cosine = concat_1d(cosine_rows.into_iter().collect::, _>>()?, 0)?; + let score = concat_1d(score_rows.into_iter().collect::, _>>()?, 0)?; let celltype = concat_1d(celltype_rows.into_iter().collect::, _>>()?, 0)?; - Ok((cosine, celltype)) + Ok((cosine, score, celltype)) } fn concat_1d( @@ -160,10 +189,11 @@ fn chunk_(i: usize, step: usize, n: usize, pad: usize) -> (Range, Range( counts: Vec>, signatures: ArrayView2<'a, F>, + pairwise_correction: ArrayView2, kernel: ArrayView2<'a, F>, unpad: (Range, Range), log: bool, -) -> (Array2, Array2) +) -> ((Array2, Array2), Array2) where C: NumCast + Copy, F: NdFloat, @@ -181,7 +211,10 @@ where // fastpath if all csx are empty None => { let shape = (unpad_r.end - unpad_r.start, unpad_c.end - unpad_c.start); - (Array2::zeros(shape), Array2::from_elem(shape, -one::())) + ( + (Array2::zeros(shape), Array2::zeros(shape)), + Array2::from_elem(shape, -one::()), + ) } Some((csx, weights)) => { let shape = csx.shape(); @@ -215,7 +248,7 @@ where .for_each(|(mut cos, &w)| cos += &kde_unpadded.map(|&x| x * w)); } // TODO: write to zarr - get_max_cosine_and_celltype(cosine, kde_norm) + get_max_cosine_and_celltype(cosine, kde_norm, pairwise_correction) } } } @@ -223,34 +256,58 @@ where fn get_max_cosine_and_celltype( cosine: Array3, kde_norm: Array2, -) -> (Array2, Array2) + pairwise_correction: ArrayView2, +) -> ((Array2, Array2), Array2) where I: PrimInt + Signed, F: NdFloat, { - let (mut max_cosine, mut celltypemap) = get_max_argmax(&cosine); + let vars = cosine.map_axis(Axis(0), |view| get_argmax2(view)); + let mut max_cosine = vars.mapv(|(c, _, _, _)| c); + let mut score = vars.mapv(|(_, s, _, _)| s); + let mut celltypemap = vars.mapv(|(_, _, i, _)| I::from(i).unwrap()); Zip::from(&mut celltypemap) .and(&mut max_cosine) + .and(&mut score) + .and(&vars) .and(&kde_norm) - .for_each(|ct, cos, &norm| { + .for_each(|ct, cos, s, (_, _, i, j), &norm| { if norm == zero() { *ct = -one::(); } else { - *cos /= norm.sqrt(); - } + let norm_sqrt = norm.sqrt(); + *cos /= norm_sqrt; + *s /= norm_sqrt * pairwise_correction[[*i, *j]]; + }; }); - (max_cosine, celltypemap) + ((max_cosine, score), celltypemap) } -pub fn get_max_argmax( - array: &Array3, -) -> (Array2, Array2) { - let argmax = array.map_axis(Axis(0), |view| view.argmax().unwrap()); - let max = Array::from_shape_fn(argmax.raw_dim(), |(i, j)| array[[argmax[[i, j]], i, j]]); - // let max = array.map_axis(Axis(0), |view| view.max().unwrap().clone()); - (max, argmax.mapv(|i| I::from(i).unwrap())) +fn get_argmax2<'a, T: Zero + PartialOrd + Copy + Sub>( + values: ArrayView1<'a, T>, +) -> (T, T, usize, usize) { + let mut max = zero(); + let mut max2 = zero(); + + let mut argmax = 0; + let mut argmax2 = 0; + + for (i, &val) in values.indexed_iter() { + if val > max2 { + if val > max { + max2 = max; + max = val; + argmax2 = argmax; + argmax = i; + } else { + max2 = val; + argmax2 = i; + } + } + } + (max, max - max2, argmax, argmax2) } #[cfg(test)] @@ -281,24 +338,24 @@ mod tests { } } - #[test] - fn test_max_argmax() { - let setup = Setup::new(); + // #[test] + // fn test_max_argmax() { + // let setup = Setup::new(); - let max_argmax: (Array2, Array2) = get_max_argmax(&setup.cosine); + // let max_argmax: (Array2, Array2) = get_max_argmax(&setup.cosine); - assert_eq!(max_argmax.0, setup.max); - assert_eq!(max_argmax.1, setup.argmax); - } + // assert_eq!(max_argmax.0, setup.max); + // assert_eq!(max_argmax.1, setup.argmax); + // } - #[test] - fn test_get_max_cosine_and_celltype() { - let setup = Setup::new(); + // #[test] + // fn test_get_max_cosine_and_celltype() { + // let setup = Setup::new(); - let cos_ct: (Array2, Array2) = - get_max_cosine_and_celltype(setup.cosine, setup.norm); + // let cos_ct: ((Array2, Array2), Array2) = + // get_max_cosine_and_celltype(setup.cosine, setup.norm); - assert_eq!(cos_ct.0, setup.cos); - assert_eq!(cos_ct.1, setup.celltype); - } + // assert_eq!(cos_ct.0 .0, setup.cos); + // assert_eq!(cos_ct.1, setup.celltype); + // } } diff --git a/src/gridcounts.rs b/src/gridcounts.rs index c1f64d3..a51e630 100644 --- a/src/gridcounts.rs +++ b/src/gridcounts.rs @@ -254,7 +254,7 @@ impl GridCounts { self.counts = counts; self.shape = shape; self.resolution = resolution; - self.set_n_threads(n_threads); + self.set_n_threads(n_threads)?; Ok(()) } @@ -274,12 +274,14 @@ impl GridCounts { Ok(((HashMap::new(),), HashMap::new())) } - // fn __iter__(&self) -> PyResult> { - // Python::with_gil(|py| PyList::new_bound(py, self.counts.keys()).iter()) - // } - // fn keys(&self) -> PyResult> { - // self.__iter__() - // } + fn __str__(&self) -> String { + let repr = vec![ + format!("GridCounts ({} threads)", self.n_threads), + format!("genes: {}", self.counts.len()), + format!("shape: {:?}", self.shape), + ]; + return repr.join("\n "); + } fn get( &self,