diff --git a/README.md b/README.md index 2a1c47ca..8f318d6c 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,6 @@ If you have used stLearn in your research, please consider citing us: -> Pham, Duy, et al. "Robust mapping of spatiotemporal trajectories and cell–cell interactions in healthy and diseased tissues." +> Pham, Duy, et al. "Robust mapping of spatiotemporal trajectories and cell–cell interactions in healthy and diseased tissues." > Nature Communications 14.1 (2023): 7739. > [https://doi.org/10.1101/2020.05.31.125658](https://doi.org/10.1038/s41467-023-43120-6) diff --git a/requirements.txt b/requirements.txt index a3c68fc3..c5452f88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ bokeh>= 2.4.2 click>=8.0.4 leidenalg louvain +numba<=0.57.1 numpy>=1.18,<1.22 Pillow>=9.0.1 scanpy>=1.8.2 diff --git a/stlearn/adds/annotation.py b/stlearn/adds/annotation.py index ca509382..a8bc1ac9 100644 --- a/stlearn/adds/annotation.py +++ b/stlearn/adds/annotation.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, List from anndata import AnnData from matplotlib import pyplot as plt from pathlib import Path @@ -7,11 +7,10 @@ def annotation( adata: AnnData, - label_list: list, + label_list: List[str], use_label: str = "louvain", copy: bool = False, ) -> Optional[AnnData]: - """\ Adding annotation for cluster @@ -38,8 +37,9 @@ def annotation( if len(label_list) != len(adata.obs[use_label].unique()): raise ValueError("Please give the correct number of label list!") - adata.obs[use_label + "_anno"] = adata.obs[use_label] - adata.obs[use_label + "_anno"].cat.categories = label_list + adata.obs[use_label + "_anno"] = adata.obs[use_label].cat.rename_categories( + label_list + ) print("The annotation is added to adata.obs['" + use_label + "_anno" + "']") diff --git a/stlearn/image_preprocessing/image_tiling.py b/stlearn/image_preprocessing/image_tiling.py index e0cd6712..bdb88a60 100644 --- a/stlearn/image_preprocessing/image_tiling.py +++ b/stlearn/image_preprocessing/image_tiling.py @@ -13,9 +13,10 @@ def tiling( adata: AnnData, out_path: Union[Path, str] = "./tiling", - library_id: str = None, + library_id: Union[str, None] = None, crop_size: int = 40, target_size: int = 299, + img_fmt: str = "JPEG", verbose: bool = False, copy: bool = False, ) -> Optional[AnnData]: @@ -78,17 +79,24 @@ def tiling( (imagecol_left, imagerow_down, imagecol_right, imagerow_up) ) tile.thumbnail((target_size, target_size), Image.Resampling.LANCZOS) - tile.resize((target_size, target_size)) + tile = tile.resize((target_size, target_size)) tile_name = str(imagecol) + "-" + str(imagerow) + "-" + str(crop_size) - out_tile = Path(out_path) / (tile_name + ".jpeg") - tile_names.append(str(out_tile)) + + if img_fmt == "JPEG": + out_tile = Path(out_path) / (tile_name + ".jpeg") + tile_names.append(str(out_tile)) + tile.save(out_tile, "JPEG") + else: + out_tile = Path(out_path) / (tile_name + ".png") + tile_names.append(str(out_tile)) + tile.save(out_tile, "PNG") + if verbose: print( "generate tile at location ({}, {})".format( str(imagecol), str(imagerow) ) ) - tile.save(out_tile, "JPEG") pbar.update(1) diff --git a/stlearn/plotting/trajectory/__init__.py b/stlearn/plotting/trajectory/__init__.py index 6638c77c..16681a51 100644 --- a/stlearn/plotting/trajectory/__init__.py +++ b/stlearn/plotting/trajectory/__init__.py @@ -1,5 +1,6 @@ from .pseudotime_plot import pseudotime_plot from .local_plot import local_plot +from .tree_plot_simple import tree_plot_simple from .tree_plot import tree_plot from .transition_markers_plot import transition_markers_plot from .DE_transition_plot import DE_transition_plot diff --git a/stlearn/plotting/trajectory/tree_plot.py b/stlearn/plotting/trajectory/tree_plot.py index 3fd33bae..90ade45f 100644 --- a/stlearn/plotting/trajectory/tree_plot.py +++ b/stlearn/plotting/trajectory/tree_plot.py @@ -33,7 +33,6 @@ def tree_plot( ncols: int = 4, copy: bool = False, ) -> Optional[AnnData]: - """\ Hierarchical tree plot represent for the global spatial trajectory inference. @@ -114,7 +113,6 @@ def tree_plot( def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5): - """ From Joel's answer at https://stackoverflow.com/a/29597209/2966723. Licensed under Creative Commons Attribution-Share Alike diff --git a/stlearn/plotting/trajectory/tree_plot_simple.py b/stlearn/plotting/trajectory/tree_plot_simple.py new file mode 100644 index 00000000..3b2395fd --- /dev/null +++ b/stlearn/plotting/trajectory/tree_plot_simple.py @@ -0,0 +1,216 @@ +from matplotlib import pyplot as plt +from PIL import Image +import pandas as pd +import matplotlib +import numpy as np +import networkx as nx +import math +import random +from stlearn._compat import Literal +from typing import Optional, Union +from anndata import AnnData +import warnings +import io +from copy import deepcopy +from stlearn.utils import _read_graph + + +def tree_plot_simple( + adata: AnnData, + library_id: str = None, + figsize: Union[float, int] = (10, 4), + data_alpha: float = 1.0, + use_label: str = "louvain", + spot_size: Union[float, int] = 50, + fontsize: int = 6, + piesize: float = 0.15, + zoom: float = 0.1, + name: str = None, + output: str = None, + dpi: int = 180, + show_all: bool = False, + show_plot: bool = True, + ncols: int = 4, + copy: bool = False, +) -> Optional[AnnData]: + """\ + Hierarchical tree plot represent for the global spatial trajectory inference. + + Parameters + ---------- + adata + Annotated data matrix. + library_id + Library id stored in AnnData. + use_label + Use label result of cluster method. + figsize + Change figure size. + data_alpha + Opacity of the spot. + fontsize + Choose font size. + piesize + Choose the size of cropped image. + zoom + Choose zoom factor. + show_all + Show all cropped image or not. + show_legend + Show legend or not. + dpi + Set dpi as the resolution for the plot. + copy + Return a copy instead of writing to adata. + Returns + ------- + Nothing + """ + + G = _read_graph(adata, "PTS_graph") + + if library_id is None: + library_id = list(adata.uns["spatial"].keys())[0] + + G.remove_node(9999) + + start_nodes = [] + disconnected_nodes = [] + for node in G.in_degree(): + if node[1] == 0: + start_nodes.append(node[0]) + + for node in G.out_degree(): + if node[1] == 0: + disconnected_nodes.append(node[0]) + + start_nodes = list(set(start_nodes) - set(disconnected_nodes)) + start_nodes.sort() + + nrows = math.ceil(len(start_nodes) / ncols) + + superfig, axs = plt.subplots(nrows, ncols, figsize=figsize) + axs = axs.ravel() + + for idx in range(0, nrows * ncols): + try: + generate_tree_viz( + adata, use_label, G, axs[idx], starter_node=start_nodes[idx] + ) + except: + axs[idx] = axs[idx].axis("off") + + if name is None: + name = use_label + + if output is not None: + superfig.savefig( + output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0 + ) + + if show_plot == True: + plt.show() + + +def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5): + """ + From Joel's answer at https://stackoverflow.com/a/29597209/2966723. + Licensed under Creative Commons Attribution-Share Alike + + If the graph is a tree this will return the positions to plot this in a + hierarchical layout. + + G: the graph (must be a tree) + + root: the root node of current branch + - if the tree is directed and this is not given, + the root will be found and used + - if the tree is directed and this is given, then + the positions will be just for the descendants of this node. + - if the tree is undirected and not given, + then a random choice will be used. + + width: horizontal space allocated for this branch - avoids overlap with other branches + + vert_gap: gap between levels of hierarchy + + vert_loc: vertical location of root + + xcenter: horizontal location of root + """ + if not nx.is_tree(G): + raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") + + if root is None: + if isinstance(G, nx.DiGraph): + root = next( + iter(nx.topological_sort(G)) + ) # allows back compatibility with nx version 1.11 + else: + root = random.choice(list(G.nodes)) + + def _hierarchy_pos( + G, root, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None + ): + """ + see hierarchy_pos docstring for most arguments + + pos: a dict saying where all nodes go if they have been assigned + parent: parent of this branch. - only affects it if non-directed + + """ + + if pos is None: + pos = {root: (xcenter, vert_loc)} + else: + pos[root] = (xcenter, vert_loc) + children = list(G.neighbors(root)) + if not isinstance(G, nx.DiGraph) and parent is not None: + children.remove(parent) + if len(children) != 0: + dx = width / len(children) + nextx = xcenter - width / 2 - dx / 2 + for child in children: + nextx += dx + pos = _hierarchy_pos( + G, + child, + width=dx, + vert_gap=vert_gap, + vert_loc=vert_loc - vert_gap, + xcenter=nextx, + pos=pos, + parent=root, + ) + return pos + + return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) + + +def generate_tree_viz(adata, use_label, G, axis, starter_node): + tmp_edges = [] + for edge in G.edges(): + if starter_node == edge[0]: + tmp_edges.append(edge) + tmp_D = nx.DiGraph() + tmp_D.add_edges_from(tmp_edges) + + pos = hierarchy_pos(tmp_D) + a = axis + + a.axis("off") + colors = [] + for n in tmp_D: + subset = adata.obs[adata.obs["sub_cluster_labels"] == str(n)] + colors.append(adata.uns[use_label + "_colors"][int(subset[use_label][0])]) + + nx.draw_networkx_edges( + tmp_D, + pos, + ax=a, + arrowstyle="-", + edge_color="#ADABAF", + connectionstyle="angle3,angleA=0,angleB=90", + ) + nx.draw_networkx_nodes(tmp_D, pos, node_color=colors, ax=a) + nx.draw_networkx_labels(tmp_D, pos, font_color="black", ax=a)