-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtci_squidpy_supp_lib.py
239 lines (191 loc) · 12.6 KB
/
tci_squidpy_supp_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# File extension (no leading dot) for the scatter/heatmap images written by calculate_squidpy_pvals()
save_image_ext = "jpg"
def get_anndata_object(coordinates=None, labels=None, seed=None, label_name='label'):
    """Obtain an AnnData object from a set of spatial coordinates and labels.

    Missing inputs are generated randomly: coordinates are drawn uniformly from [0, 10) in
    each of two dimensions, and labels are drawn from {'leaf', 'tree', 'flower', 'acorn'}.
    The number of generated values matches the length of the other input when that input
    is supplied, and is 100 when both are missing.

    Args:
        coordinates (NumPy array, optional): Array of shape (N, 2), where N is the number of observations, such as cells. Defaults to None.
        labels (list, optional): List of N values, which if not strings will get converted to them, corresponding to labels for the observations, such as phenotype. Defaults to None.
        seed (int, optional): Random seed used for any generated coordinates and/or labels. Defaults to None.
        label_name (str, optional): Name for the labels, such as 'phenotype'. Defaults to 'label'.

    Returns:
        AnnData: AnnData object with the coordinates in obsm['spatial'] and the string labels in obs[label_name].
        dtype: Original datatype of the labels (str when the labels were generated or are empty).
    """
    # Import required libraries
    from anndata import AnnData
    import pandas as pd
    from numpy.random import default_rng

    # Size of any generated input: follow whichever input was supplied, so that
    # coordinates and labels always have the same number of observations
    if coordinates is not None:
        n_obs = len(coordinates)
    elif labels is not None:
        n_obs = len(labels)
    else:
        n_obs = 100

    # If no coordinates are input, generate random ones
    if coordinates is None:
        rng = default_rng(seed)
        coordinates = rng.uniform(0, 10, size=(n_obs, 2))

    # If no labels are input, generate random ones; otherwise make sure they are strings
    if labels is None:
        # A fresh generator with the same seed keeps the draws identical to earlier versions
        rng = default_rng(seed)
        label_map = {5: 'leaf', 6: 'tree', 7: 'flower', 8: 'acorn'}
        labels = [label_map[x] for x in rng.integers(5, 9, size=(n_obs,))]
        labels_type_orig = str
    else:
        # next(iter(...)) is positional for lists, arrays and Series alike; empty input defaults to str
        labels_type_orig = type(next(iter(labels))) if len(labels) > 0 else str
        if labels_type_orig is not str:
            labels = [str(x) for x in labels]

    # Generate the AnnData object; naming the column directly avoids relying on a default column name
    adata = AnnData(None, obsm={'spatial': coordinates}, obs=pd.DataFrame({label_name: labels}))

    # Return the AnnData object and the original datatype of the labels
    return adata, labels_type_orig
def squidpy_scatter(adata, label_name='label', circle_size=50, image_pathname=None, dpi=150):
    """Draw a Squidpy spatial scatter plot of the observations, optionally writing it to disk.

    Args:
        adata (AnnData): Dataset to plot, e.g. as produced by get_anndata_object().
        label_name (str, optional): obs column used to color the points, such as 'phenotype'. Defaults to 'label'.
        circle_size (int, optional): Marker size, passed as `size` to sq.pl.spatial_scatter(). Defaults to 50.
        image_pathname (str, optional): Where to save the figure, with or without the file extension; nothing is saved when None. Defaults to None.
        dpi (int, optional): Resolution used when saving the figure. Defaults to 150.

    Returns:
        AnnData: The same AnnData object that was passed in (possibly modified in place by the plotting call).
        figure: matplotlib.pyplot figure handle of the scatter plot.
    """
    import matplotlib.pyplot as plt
    import squidpy as sq

    sq.pl.spatial_scatter(adata, shape=None, color=label_name, size=circle_size)

    # Grab the figure that was just drawn through pyplot's current-figure handle
    fig = plt.gcf()

    if image_pathname is not None:
        fig.savefig(image_pathname, dpi=dpi, bbox_inches='tight')

    return adata, fig
def squidpy_enrichment(adata, radius=3.0, n_neighs=6, radius_instead_of_knn=True, label_name='label', annotate=False, image_pathname=None, dpi=150, print_heatmap_data=False, return_heatmap_data=False, labels_type_orig=str, check_symmetric=False, n_jobs=1):
    """Run Squidpy's neighborhood enrichment analysis, generating the spatial graph, running the analysis, generating the heatmap of z-scores, optionally saving the figure to disk, and, importantly, extracting the z-scores from the AnnData object.

    Args:
        adata (AnnData): AnnData object describing the dataset.
        radius (float, optional): Radius likely in units of the coordinates to use for the graph generation. See sq.gr.spatial_neighbors() or online documentation for more details. Defaults to 3.0.
        n_neighs (int, optional): Number of nearest neighbors used for the graph when radius_instead_of_knn is False. Defaults to 6.
        radius_instead_of_knn (bool, optional): Build a radius-based graph if True, otherwise a k-nearest-neighbors graph. Defaults to True.
        label_name (str, optional): Name for the labels, such as 'phenotype'. Defaults to 'label'.
        annotate (bool, optional): Whether to include the z-score values for each heatmap square. Defaults to False.
        image_pathname (str, optional): Path to the image to be generated, with or without the file extension. Defaults to None.
        dpi (int, optional): Dots per inch for the figure that's optionally generated. Defaults to 150.
        print_heatmap_data (bool, optional): Whether to print the z-scores and corresponding row/col labels. Defaults to False.
        return_heatmap_data (bool, optional): Whether to return the z-scores and row/col labels. Defaults to False.
        labels_type_orig (dtype, optional): Original datatype of the labels; the extracted row/col labels are cast back to it. Defaults to str.
        check_symmetric (bool, optional): Whether to check that the z-score matrix is symmetric (numpy.allclose, NaNs treated as equal) and print a warning if it is not. Defaults to False.
        n_jobs (int, optional): Number of parallel jobs passed to sq.gr.nhood_enrichment(). Defaults to 1.

    Returns:
        Always:
            AnnData: AnnData object.
            figure: matplotlib.pyplot figure handle.
        Optional (depending on return_heatmap_data):
            NumPy array: M x M array of z-scores, where M is the number of unique labels.
            list: List of unique labels.
    """
    # Import relevant libraries
    import numpy as np
    import squidpy as sq
    import matplotlib.pyplot as plt

    # Generate the spatial graph
    if radius_instead_of_knn:
        print('Generating radius-based graphs in Squidpy')
        sq.gr.spatial_neighbors(adata, radius=radius, coord_type='generic')
    else:
        print('Generating k-nearest neighbors-based graphs in Squidpy')
        sq.gr.spatial_neighbors(adata, n_neighs=n_neighs, coord_type='generic')

    # Calculate the neighborhood enrichment (n_jobs was previously accepted but never forwarded)
    sq.gr.nhood_enrichment(adata, cluster_key=label_name, n_jobs=n_jobs)

    # Generate the heatmap and get its figure handle
    sq.pl.nhood_enrichment(adata, cluster_key=label_name, annotate=annotate)
    enrichment_plot = plt.gcf()

    # Optionally save the image to disk
    if image_pathname is not None:
        enrichment_plot.savefig(image_pathname, dpi=dpi, bbox_inches='tight')

    want_heatmap_data = print_heatmap_data or return_heatmap_data

    # Extract the z-scores whenever they are needed downstream
    if want_heatmap_data or check_symmetric:
        zscores = adata.uns[label_name + '_nhood_enrichment']['zscore']

    # Optionally verify that the z-score matrix is symmetric, warning rather than failing
    if check_symmetric:
        zscores_arr = np.asarray(zscores, dtype=float)
        if not np.allclose(zscores_arr, zscores_arr.T, equal_nan=True):
            print('WARNING: the neighborhood enrichment z-score matrix is not symmetric')

    if want_heatmap_data:
        # Row/col labels are read from the heatmap's y tick labels on axes[1], reversed.
        # NOTE(review): this relies on Squidpy's heatmap layout; presumably it matches the
        # category order of adata.obs[label_name] — confirm before relying on it elsewhere.
        heatmap_rowcol_labels = [x.get_text() for x in enrichment_plot.axes[1].get_yticklabels()[::-1]]
        # If custom labels were not already strings, convert the ordered list back to the original dtype
        if labels_type_orig is not str:
            heatmap_rowcol_labels = [labels_type_orig(x) for x in heatmap_rowcol_labels]

    # Optionally print the heatmap data
    if print_heatmap_data:
        print('* Extracted z-scores:', zscores.round(2), sep='\n')
        print('* Extracted row/col labels:', heatmap_rowcol_labels, sep='\n')

    # Return the updated AnnData object and the plot (plus the heatmap data if requested)
    if return_heatmap_data:
        return adata, enrichment_plot, zscores, heatmap_rowcol_labels
    return adata, enrichment_plot
def zscores_to_pvals(zscores):
    """Convert z-scores to left- and right-tailed P values under a standard normal null.

    Args:
        zscores (NumPy array): M x M array of z-scores, where M is the number of unique labels.

    Returns:
        NumPy array: M x M array of left P values, i.e., P(Z <= z).
        NumPy array: M x M array of right P values, i.e., P(Z >= z).
    """
    # Import relevant library
    import scipy.stats

    # Left-tail P values from the standard normal CDF
    pvals_left = scipy.stats.norm.cdf(zscores)

    # Right-tail P values from the survival function: computing 1 - cdf(z) instead would
    # round to exactly 0 for large z (roughly z > 8) and lose precision well before that
    pvals_right = scipy.stats.norm.sf(zscores)

    # Return the P values
    return pvals_left, pvals_right
def calculate_squidpy_pvals(coordinates=None, labels=None, seed=None, label_name='label', circle_size=50, image_path_prefix=None, dpi=150, radius=3.0, n_neighs=6, radius_instead_of_knn=True, annotate_heatmap=True, print_heatmap_data=False, close_figs=True, return_anndata_obj=False, n_jobs=1):
    """From input coordinates and corresponding labels, run Squidpy's neighborhood enrichment analysis and return the corresponding P values, plotting the scatter plot along the way.

    Args:
        coordinates (NumPy array, optional): Array of shape (N, 2), where N is the number of observations, such as cells. Defaults to None.
        labels (list, optional): List of N values, which if not strings will get converted to them, corresponding to labels for the observations, such as phenotype. Defaults to None.
        seed (int, optional): Random seed. Defaults to None.
        label_name (str, optional): Name for the labels, such as 'phenotype'. Defaults to 'label'.
        circle_size (int, optional): Value describing the marker size as detailed in sq.pl.spatial_scatter(). Defaults to 50.
        image_path_prefix (str, optional): String prefix to the image filenames to be generated, optionally including the directory. If it includes part of the basename for the image file to be generated, be sure to end with a hyphen or the like. Defaults to None.
        dpi (int, optional): Dots per inch for the figures that are optionally generated. Defaults to 150.
        radius (float, optional): Radius likely in units of the coordinates to use for the graph generation. See sq.gr.spatial_neighbors() or online documentation for more details. Defaults to 3.0.
        n_neighs (int, optional): Number of nearest neighbors used for the graph when radius_instead_of_knn is False. Defaults to 6.
        radius_instead_of_knn (bool, optional): Build a radius-based graph if True, otherwise a k-nearest-neighbors graph. Defaults to True.
        annotate_heatmap (bool, optional): Whether to include the z-score values for each heatmap square. Defaults to True.
        print_heatmap_data (bool, optional): Whether to print the z-scores and corresponding row/col labels. Defaults to False.
        close_figs (bool, optional): Whether to close the figures that are generated. If running en-masse, make sure this is True, but if running in a Jupyter notebook and examining the output, set this to False to ensure the images are output to the screen. Defaults to True.
        return_anndata_obj (bool, optional): Whether to also return the anndata object, adata. Defaults to False.
        n_jobs (int, optional): Number of parallel jobs forwarded to squidpy_enrichment(). Defaults to 1.

    Returns:
        NumPy array: M x M array of left P values (None if the scatter plot raised a ValueError).
        NumPy array: M x M array of right P values (None in that same failure case).
        list: List of unique labels (None in that same failure case).
        anndata (optional): Created anndata object, returned only when return_anndata_obj is True.
    """
    # Import relevant libraries
    import os
    import matplotlib.pyplot as plt

    # Determine the directory and/or basename prefix if it's desired that the scatter plot and heatmap be saved
    if image_path_prefix is not None:
        scatter_image_pathname = '{}scatter.{}'.format(image_path_prefix, save_image_ext)
        heatmap_image_pathname = '{}heatmap.{}'.format(image_path_prefix, save_image_ext)
    else:
        scatter_image_pathname = None
        heatmap_image_pathname = None

    # Obtain an AnnData object from the coordinates and labels
    adata, labels_type_orig = get_anndata_object(coordinates=coordinates, labels=labels, seed=seed, label_name=label_name)

    # Plot a scatter plot of the data using Squidpy's functionality; a ValueError here means the
    # enrichment step would fail too, so it is skipped
    try:
        adata, scatter_plot = squidpy_scatter(adata, label_name=label_name, circle_size=circle_size, image_pathname=scatter_image_pathname, dpi=dpi)
    except ValueError:
        print('ValueError for image {}; excepting this error and skipping the rest of the Squidpy P value calculation since related errors would be thrown'.format(scatter_image_pathname))
        keep_going = False
    else:
        keep_going = True

    if keep_going:
        # Plot a heatmap using Squidpy's nhood_enrichment() method and return the result, which is a set of z-scores
        adata, enrichment_plot, zscores, heatmap_rowcol_labels = squidpy_enrichment(adata, radius=radius, n_neighs=n_neighs, radius_instead_of_knn=radius_instead_of_knn, label_name=label_name, annotate=annotate_heatmap, image_pathname=heatmap_image_pathname, dpi=dpi, print_heatmap_data=print_heatmap_data, return_heatmap_data=True, labels_type_orig=labels_type_orig, n_jobs=n_jobs)
        # Close the generated figures
        if close_figs:
            plt.close(scatter_plot)
            plt.close(enrichment_plot)
        # Convert the z-scores to P values
        pvals_left, pvals_right = zscores_to_pvals(zscores)
    else:
        logs_dir = os.path.join('.', 'output', 'logs')
        # makedirs also creates ./output when missing (os.mkdir raised FileNotFoundError in that case)
        os.makedirs(logs_dir, exist_ok=True)
        # Append so that earlier failures in en-masse runs are not overwritten by later ones
        with open(os.path.join(logs_dir, 'squidpy.log'), 'a', encoding='utf-8') as f:
            f.write('Squidpy is skipping scatter plot and neighborhood enrichment of images {} and {} due to failures\n'.format(scatter_image_pathname, heatmap_image_pathname))
        pvals_left, pvals_right, heatmap_rowcol_labels = None, None, None

    # Return the P values and the corresponding annotation labels (plus the AnnData object if requested)
    if return_anndata_obj:
        return pvals_left, pvals_right, heatmap_rowcol_labels, adata
    return pvals_left, pvals_right, heatmap_rowcol_labels