Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix current issues #273

Merged
merged 3 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@

If you have used stLearn in your research, please consider citing us:

> Pham, Duy, et al. "Robust mapping of spatiotemporal trajectories and cell–cell interactions in healthy and diseased tissues."
> Pham, Duy, et al. "Robust mapping of spatiotemporal trajectories and cell–cell interactions in healthy and diseased tissues."
> Nature Communications 14.1 (2023): 7739.
> [https://doi.org/10.1101/2020.05.31.125658](https://doi.org/10.1038/s41467-023-43120-6)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ bokeh>= 2.4.2
click>=8.0.4
leidenalg
louvain
numba<=0.57.1
numpy>=1.18,<1.22
Pillow>=9.0.1
scanpy>=1.8.2
Expand Down
10 changes: 5 additions & 5 deletions stlearn/adds/annotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, List
from anndata import AnnData
from matplotlib import pyplot as plt
from pathlib import Path
Expand All @@ -7,11 +7,10 @@

def annotation(
adata: AnnData,
label_list: list,
label_list: List[str],
use_label: str = "louvain",
copy: bool = False,
) -> Optional[AnnData]:

"""\
Adding annotation for cluster

Expand All @@ -38,8 +37,9 @@ def annotation(
if len(label_list) != len(adata.obs[use_label].unique()):
raise ValueError("Please give the correct number of label list!")

adata.obs[use_label + "_anno"] = adata.obs[use_label]
adata.obs[use_label + "_anno"].cat.categories = label_list
adata.obs[use_label + "_anno"] = adata.obs[use_label].cat.rename_categories(
label_list
)

print("The annotation is added to adata.obs['" + use_label + "_anno" + "']")

Expand Down
18 changes: 13 additions & 5 deletions stlearn/image_preprocessing/image_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
def tiling(
adata: AnnData,
out_path: Union[Path, str] = "./tiling",
library_id: str = None,
library_id: Union[str, None] = None,
crop_size: int = 40,
target_size: int = 299,
img_fmt: str = "JPEG",
verbose: bool = False,
copy: bool = False,
) -> Optional[AnnData]:
Expand Down Expand Up @@ -78,17 +79,24 @@ def tiling(
(imagecol_left, imagerow_down, imagecol_right, imagerow_up)
)
tile.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)
tile.resize((target_size, target_size))
tile = tile.resize((target_size, target_size))
tile_name = str(imagecol) + "-" + str(imagerow) + "-" + str(crop_size)
out_tile = Path(out_path) / (tile_name + ".jpeg")
tile_names.append(str(out_tile))

if img_fmt == "JPEG":
out_tile = Path(out_path) / (tile_name + ".jpeg")
tile_names.append(str(out_tile))
tile.save(out_tile, "JPEG")
else:
out_tile = Path(out_path) / (tile_name + ".png")
tile_names.append(str(out_tile))
tile.save(out_tile, "PNG")

if verbose:
print(
"generate tile at location ({}, {})".format(
str(imagecol), str(imagerow)
)
)
tile.save(out_tile, "JPEG")

pbar.update(1)

Expand Down
1 change: 1 addition & 0 deletions stlearn/plotting/trajectory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .pseudotime_plot import pseudotime_plot
from .local_plot import local_plot
from .tree_plot_simple import tree_plot_simple
from .tree_plot import tree_plot
from .transition_markers_plot import transition_markers_plot
from .DE_transition_plot import DE_transition_plot
Expand Down
2 changes: 0 additions & 2 deletions stlearn/plotting/trajectory/tree_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def tree_plot(
ncols: int = 4,
copy: bool = False,
) -> Optional[AnnData]:

"""\
Hierarchical tree plot represent for the global spatial trajectory inference.
Expand Down Expand Up @@ -114,7 +113,6 @@ def tree_plot(


def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5):

"""
From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
Licensed under Creative Commons Attribution-Share Alike
Expand Down
216 changes: 216 additions & 0 deletions stlearn/plotting/trajectory/tree_plot_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from matplotlib import pyplot as plt
from PIL import Image
import pandas as pd
import matplotlib
import numpy as np
import networkx as nx
import math
import random
from stlearn._compat import Literal
from typing import Optional, Union
from anndata import AnnData
import warnings
import io
from copy import deepcopy
from stlearn.utils import _read_graph


def tree_plot_simple(
adata: AnnData,
library_id: str = None,
figsize: Union[float, int] = (10, 4),
data_alpha: float = 1.0,
use_label: str = "louvain",
spot_size: Union[float, int] = 50,
fontsize: int = 6,
piesize: float = 0.15,
zoom: float = 0.1,
name: str = None,
output: str = None,
dpi: int = 180,
show_all: bool = False,
show_plot: bool = True,
ncols: int = 4,
copy: bool = False,
) -> Optional[AnnData]:
"""\
Hierarchical tree plot represent for the global spatial trajectory inference.

Parameters
----------
adata
Annotated data matrix.
library_id
Library id stored in AnnData.
use_label
Use label result of cluster method.
figsize
Change figure size.
data_alpha
Opacity of the spot.
fontsize
Choose font size.
piesize
Choose the size of cropped image.
zoom
Choose zoom factor.
show_all
Show all cropped image or not.
show_legend
Show legend or not.
dpi
Set dpi as the resolution for the plot.
copy
Return a copy instead of writing to adata.
Returns
-------
Nothing
"""

G = _read_graph(adata, "PTS_graph")

if library_id is None:
library_id = list(adata.uns["spatial"].keys())[0]

G.remove_node(9999)

start_nodes = []
disconnected_nodes = []
for node in G.in_degree():
if node[1] == 0:
start_nodes.append(node[0])

for node in G.out_degree():
if node[1] == 0:
disconnected_nodes.append(node[0])

start_nodes = list(set(start_nodes) - set(disconnected_nodes))
start_nodes.sort()

nrows = math.ceil(len(start_nodes) / ncols)

superfig, axs = plt.subplots(nrows, ncols, figsize=figsize)
axs = axs.ravel()

for idx in range(0, nrows * ncols):
try:
generate_tree_viz(
adata, use_label, G, axs[idx], starter_node=start_nodes[idx]
)
except:
axs[idx] = axs[idx].axis("off")

if name is None:
name = use_label

if output is not None:
superfig.savefig(
output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0
)

if show_plot == True:
plt.show()


def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5):
"""
From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
Licensed under Creative Commons Attribution-Share Alike

If the graph is a tree this will return the positions to plot this in a
hierarchical layout.

G: the graph (must be a tree)

root: the root node of current branch
- if the tree is directed and this is not given,
the root will be found and used
- if the tree is directed and this is given, then
the positions will be just for the descendants of this node.
- if the tree is undirected and not given,
then a random choice will be used.

width: horizontal space allocated for this branch - avoids overlap with other branches

vert_gap: gap between levels of hierarchy

vert_loc: vertical location of root

xcenter: horizontal location of root
"""
if not nx.is_tree(G):
raise TypeError("cannot use hierarchy_pos on a graph that is not a tree")

if root is None:
if isinstance(G, nx.DiGraph):
root = next(
iter(nx.topological_sort(G))
) # allows back compatibility with nx version 1.11
else:
root = random.choice(list(G.nodes))

def _hierarchy_pos(
G, root, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None
):
"""
see hierarchy_pos docstring for most arguments

pos: a dict saying where all nodes go if they have been assigned
parent: parent of this branch. - only affects it if non-directed

"""

if pos is None:
pos = {root: (xcenter, vert_loc)}
else:
pos[root] = (xcenter, vert_loc)
children = list(G.neighbors(root))
if not isinstance(G, nx.DiGraph) and parent is not None:
children.remove(parent)
if len(children) != 0:
dx = width / len(children)
nextx = xcenter - width / 2 - dx / 2
for child in children:
nextx += dx
pos = _hierarchy_pos(
G,
child,
width=dx,
vert_gap=vert_gap,
vert_loc=vert_loc - vert_gap,
xcenter=nextx,
pos=pos,
parent=root,
)
return pos

return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)


def generate_tree_viz(adata, use_label, G, axis, starter_node):
tmp_edges = []
for edge in G.edges():
if starter_node == edge[0]:
tmp_edges.append(edge)
tmp_D = nx.DiGraph()
tmp_D.add_edges_from(tmp_edges)

pos = hierarchy_pos(tmp_D)
a = axis

a.axis("off")
colors = []
for n in tmp_D:
subset = adata.obs[adata.obs["sub_cluster_labels"] == str(n)]
colors.append(adata.uns[use_label + "_colors"][int(subset[use_label][0])])

nx.draw_networkx_edges(
tmp_D,
pos,
ax=a,
arrowstyle="-",
edge_color="#ADABAF",
connectionstyle="angle3,angleA=0,angleB=90",
)
nx.draw_networkx_nodes(tmp_D, pos, node_color=colors, ax=a)
nx.draw_networkx_labels(tmp_D, pos, font_color="black", ax=a)
Loading