Skip to content

Commit

Permalink
fix: fix current issues
Browse files Browse the repository at this point in the history
  • Loading branch information
duypham2108 committed Dec 28, 2023
1 parent 1e12664 commit 79700e0
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@

If you have used stLearn in your research, please consider citing us:

> Pham, Duy, et al. "Robust mapping of spatiotemporal trajectories and cell–cell interactions in healthy and diseased tissues."
> Pham, Duy, et al. "Robust mapping of spatiotemporal trajectories and cell–cell interactions in healthy and diseased tissues."
> Nature Communications 14.1 (2023): 7739.
> [https://doi.org/10.1101/2020.05.31.125658](https://doi.org/10.1038/s41467-023-43120-6)
10 changes: 5 additions & 5 deletions stlearn/adds/annotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, List
from anndata import AnnData
from matplotlib import pyplot as plt
from pathlib import Path
Expand All @@ -7,11 +7,10 @@

def annotation(
adata: AnnData,
label_list: list,
label_list: List[str],
use_label: str = "louvain",
copy: bool = False,
) -> Optional[AnnData]:

"""\
Adding annotation for cluster
Expand All @@ -38,8 +37,9 @@ def annotation(
if len(label_list) != len(adata.obs[use_label].unique()):
raise ValueError("Please give the correct number of label list!")

adata.obs[use_label + "_anno"] = adata.obs[use_label]
adata.obs[use_label + "_anno"].cat.categories = label_list
adata.obs[use_label + "_anno"] = adata.obs[use_label].cat.rename_categories(
label_list
)

print("The annotation is added to adata.obs['" + use_label + "_anno" + "']")

Expand Down
4 changes: 2 additions & 2 deletions stlearn/image_preprocessing/image_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def tiling(
adata: AnnData,
out_path: Union[Path, str] = "./tiling",
library_id: str = None,
library_id: Union[str, None] = None,
crop_size: int = 40,
target_size: int = 299,
verbose: bool = False,
Expand Down Expand Up @@ -78,7 +78,7 @@ def tiling(
(imagecol_left, imagerow_down, imagecol_right, imagerow_up)
)
tile.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)
tile.resize((target_size, target_size))
tile = tile.resize((target_size, target_size))
tile_name = str(imagecol) + "-" + str(imagerow) + "-" + str(crop_size)
out_tile = Path(out_path) / (tile_name + ".jpeg")
tile_names.append(str(out_tile))
Expand Down
1 change: 1 addition & 0 deletions stlearn/plotting/trajectory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .pseudotime_plot import pseudotime_plot
from .local_plot import local_plot
from .tree_plot_simple import tree_plot_simple
from .tree_plot import tree_plot
from .transition_markers_plot import transition_markers_plot
from .DE_transition_plot import DE_transition_plot
Expand Down
2 changes: 0 additions & 2 deletions stlearn/plotting/trajectory/tree_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def tree_plot(
ncols: int = 4,
copy: bool = False,
) -> Optional[AnnData]:

"""\
Hierarchical tree plot represent for the global spatial trajectory inference.
Expand Down Expand Up @@ -114,7 +113,6 @@ def tree_plot(


def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5):

"""
From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
Licensed under Creative Commons Attribution-Share Alike
Expand Down
216 changes: 216 additions & 0 deletions stlearn/plotting/trajectory/tree_plot_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from matplotlib import pyplot as plt
from PIL import Image
import pandas as pd
import matplotlib
import numpy as np
import networkx as nx
import math
import random
from stlearn._compat import Literal
from typing import Optional, Union
from anndata import AnnData
import warnings
import io
from copy import deepcopy
from stlearn.utils import _read_graph


def tree_plot_simple(
adata: AnnData,
library_id: str = None,
figsize: Union[float, int] = (10, 4),
data_alpha: float = 1.0,
use_label: str = "louvain",
spot_size: Union[float, int] = 50,
fontsize: int = 6,
piesize: float = 0.15,
zoom: float = 0.1,
name: str = None,
output: str = None,
dpi: int = 180,
show_all: bool = False,
show_plot: bool = True,
ncols: int = 4,
copy: bool = False,
) -> Optional[AnnData]:
"""\
Hierarchical tree plot represent for the global spatial trajectory inference.
Parameters
----------
adata
Annotated data matrix.
library_id
Library id stored in AnnData.
use_label
Use label result of cluster method.
figsize
Change figure size.
data_alpha
Opacity of the spot.
fontsize
Choose font size.
piesize
Choose the size of cropped image.
zoom
Choose zoom factor.
show_all
Show all cropped image or not.
show_legend
Show legend or not.
dpi
Set dpi as the resolution for the plot.
copy
Return a copy instead of writing to adata.
Returns
-------
Nothing
"""

G = _read_graph(adata, "PTS_graph")

if library_id is None:
library_id = list(adata.uns["spatial"].keys())[0]

G.remove_node(9999)

start_nodes = []
disconnected_nodes = []
for node in G.in_degree():
if node[1] == 0:
start_nodes.append(node[0])

for node in G.out_degree():
if node[1] == 0:
disconnected_nodes.append(node[0])

start_nodes = list(set(start_nodes) - set(disconnected_nodes))
start_nodes.sort()

nrows = math.ceil(len(start_nodes) / ncols)

superfig, axs = plt.subplots(nrows, ncols, figsize=figsize)
axs = axs.ravel()

for idx in range(0, nrows * ncols):
try:
generate_tree_viz(
adata, use_label, G, axs[idx], starter_node=start_nodes[idx]
)
except:
axs[idx] = axs[idx].axis("off")

if name is None:
name = use_label

if output is not None:
superfig.savefig(
output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0
)

if show_plot == True:
plt.show()


def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5):
"""
From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
Licensed under Creative Commons Attribution-Share Alike
If the graph is a tree this will return the positions to plot this in a
hierarchical layout.
G: the graph (must be a tree)
root: the root node of current branch
- if the tree is directed and this is not given,
the root will be found and used
- if the tree is directed and this is given, then
the positions will be just for the descendants of this node.
- if the tree is undirected and not given,
then a random choice will be used.
width: horizontal space allocated for this branch - avoids overlap with other branches
vert_gap: gap between levels of hierarchy
vert_loc: vertical location of root
xcenter: horizontal location of root
"""
if not nx.is_tree(G):
raise TypeError("cannot use hierarchy_pos on a graph that is not a tree")

if root is None:
if isinstance(G, nx.DiGraph):
root = next(
iter(nx.topological_sort(G))
) # allows back compatibility with nx version 1.11
else:
root = random.choice(list(G.nodes))

def _hierarchy_pos(
G, root, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None
):
"""
see hierarchy_pos docstring for most arguments
pos: a dict saying where all nodes go if they have been assigned
parent: parent of this branch. - only affects it if non-directed
"""

if pos is None:
pos = {root: (xcenter, vert_loc)}
else:
pos[root] = (xcenter, vert_loc)
children = list(G.neighbors(root))
if not isinstance(G, nx.DiGraph) and parent is not None:
children.remove(parent)
if len(children) != 0:
dx = width / len(children)
nextx = xcenter - width / 2 - dx / 2
for child in children:
nextx += dx
pos = _hierarchy_pos(
G,
child,
width=dx,
vert_gap=vert_gap,
vert_loc=vert_loc - vert_gap,
xcenter=nextx,
pos=pos,
parent=root,
)
return pos

return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)


def generate_tree_viz(adata, use_label, G, axis, starter_node):
tmp_edges = []
for edge in G.edges():
if starter_node == edge[0]:
tmp_edges.append(edge)
tmp_D = nx.DiGraph()
tmp_D.add_edges_from(tmp_edges)

pos = hierarchy_pos(tmp_D)
a = axis

a.axis("off")
colors = []
for n in tmp_D:
subset = adata.obs[adata.obs["sub_cluster_labels"] == str(n)]
colors.append(adata.uns[use_label + "_colors"][int(subset[use_label][0])])

nx.draw_networkx_edges(
tmp_D,
pos,
ax=a,
arrowstyle="-",
edge_color="#ADABAF",
connectionstyle="angle3,angleA=0,angleB=90",
)
nx.draw_networkx_nodes(tmp_D, pos, node_color=colors, ax=a)
nx.draw_networkx_labels(tmp_D, pos, font_color="black", ax=a)

0 comments on commit 79700e0

Please sign in to comment.