-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathsweep_robustness.py
218 lines (188 loc) · 4.8 KB
/
sweep_robustness.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# %%
import numpy as np
import scvelo as scv
import torch
from umap import UMAP
from sklearn.decomposition import PCA
from scipy.stats import mannwhitneyu
import wandb
from deepvelo.utils import velocity, velocity_confidence, update_dict
from deepvelo.utils.preprocess import autoset_coeff_s
from deepvelo.utils.plot import statplot, compare_plot
from deepvelo import train, Constants
# Default hyperparameters. They are handed to wandb as the run config and read
# back through `wargs` by the rest of the script (seed, network layers, data
# loader topC/topG, learning rate, loss scale, and preprocessing settings).
hyperparameter_defaults = {
    "seed": 123,
    "layers": [64, 64],
    "topC": 30,
    "topG": 20,
    "lr": 0.001,
    "pearson_scale": 18.0,
    "pp_hvg": 2000,
    "pp_neighbors": 30,
    "pp_pcs": 30,
    # NOTE: add any hyperparameters you want to sweep here
}
run = wandb.init(
    project="scFormer",
    config=hyperparameter_defaults,
    reinit=True,
)
# Every later hyperparameter lookup goes through the wandb config object.
wargs = wandb.config
# fix random seeds for reproducibility
SEED = wargs.seed
np.random.seed(SEED)
torch.manual_seed(SEED)
# cuDNN flags: deterministic mode on, benchmark mode off
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# scvelo verbosity levels: errors(0), warnings(1), info(2), hints(3)
scv.settings.verbosity = 3
# for beautified visualization
scv.settings.set_figure_params("scvelo", transparent=False)
# %% [markdown]
# # Load DG data and preprocess
# %%
# Load scvelo's dentate gyrus dataset and preprocess it with the swept
# preprocessing hyperparameters. Neither call's return value is captured, so
# both are relied on to update `adata` itself.
adata = scv.datasets.dentategyrus()
scv.pp.filter_and_normalize(
    adata,
    n_top_genes=wargs.pp_hvg,
    min_shared_counts=20,
)
scv.pp.moments(
    adata,
    n_pcs=wargs.pp_pcs,
    n_neighbors=wargs.pp_neighbors,
)
# %% [markdown]
# # DeepVelo
# %%
# Run-specific settings that override the default configs.
# NOTE: see train.py for complete args
configs = {
    "name": "DeepVelo",  # name of the experiment
    "arch": {"args": {"layers": wargs.layers}},
    "data_loader": {"args": {"topC": wargs.topC, "topG": wargs.topG}},
    "optimizer": {"args": {"lr": wargs.lr}},
    "loss": {
        "args": {
            "pearson_scale": wargs.pearson_scale,
            # coeff_s is derived from the preprocessed data, not swept
            "coeff_s": autoset_coeff_s(adata),
        },
    },
    "trainer": {"verbosity": 0},  # increase verbosity to show training progress
}
# Combine the overrides above with the package-level default configs.
configs = update_dict(Constants.default_configs, configs)
# %%
# initial velocity
# NOTE(review): presumably this writes a starting velocity estimate onto
# `adata` that train() then refines -- confirm in deepvelo.utils.velocity.
velocity(adata, mask_zero=False)
# Train DeepVelo on `adata` with the configs built above. The returned trainer
# is kept but not referenced again in the visible part of this script.
trainer = train(adata, configs)
# %%
# Compute the scvelo velocity graph with 8 parallel jobs; the embedding plots
# that follow presumably depend on it.
scv.tl.velocity_graph(adata, n_jobs=8)
# %%
# velocity plot
# Streamline view of the velocities on the UMAP basis, colored by "clusters".
# All figures in this script are created with show=False.
scv.pl.velocity_embedding_stream(
    adata, basis="umap", color="clusters", legend_fontsize=9,
    dpi=150,  # increase dpi for higher resolution
    show=False,
)
# NOTE: may log the plot to wandb using wandb.log({"velocity_embedding_stream": wandb.Image(plt)})
# %%
# Velocity arrows on the UMAP basis.
scv.pl.velocity_embedding(
    adata, basis="umap", arrow_length=6, arrow_size=1.2,
    dpi=150, show=False,
)
# %%
# Grid version of the velocity arrows on the UMAP basis, drawn in blue.
scv.pl.velocity_embedding_grid(
    adata, basis="umap", arrow_length=4, arrow_size=2,
    # alpha=0.1,
    arrow_color="tab:blue",
    dpi=150, show=False,
)
# %%
# get kinetic_rates
# The beta and gamma layers are always used; the alpha layer is appended only
# when it exists in `adata.layers`. The column order (beta, gamma, then alpha)
# is the same as in the two original branches.
rate_layer_keys = ["cell_specific_beta", "cell_specific_gamma"]
if "cell_specific_alpha" in adata.layers:
    rate_layer_keys.append("cell_specific_alpha")
# Stack the selected layers side by side (axis=1).
all_rates = np.concatenate(
    [adata.layers[key] for key in rate_layer_keys],
    axis=1,
)
# pca and umap of all rates
# Both reductions are seeded with SEED so repeated runs give the same layout.
rates_pca = PCA(n_components=30, random_state=SEED).fit_transform(all_rates)
adata.obsm["X_rates_pca"] = rates_pca
# UMAP is fitted on the 30 PCA components, not on the raw concatenated rates.
rates_umap = UMAP(
    n_neighbors=60,
    min_dist=0.6,
    spread=0.9,
    random_state=SEED,
).fit_transform(rates_pca)
adata.obsm["X_rates_umap"] = rates_umap
# %%
# plot kinetic rates umap
# Scatter on the "rates_umap" basis (presumably resolved to
# adata.obsm["X_rates_umap"] -- confirm against scvelo's basis lookup), with
# outlines for the three named clusters and no legend.
scv.pl.scatter(
    adata,
    basis="rates_umap",
    title="umap of cell-specific kinetic rates",
    add_outline="Granule mature, Granule immature, Neuroblast",
    outline_width=(0.15, 0.3),
    legend_loc="none",
    dpi=150,
    show=False,
    # omit_velocity_fit=True,
)
# %%
# plot genes
# Per-gene velocity plots for two genes on the UMAP basis.
scv.pl.velocity(adata, var_names=["Tmsb10", "Ppp3ca"], basis="umap", show=False)
# %%
# save adata for next steps
deepvelo_adata = adata.copy()
# %% [markdown]
# # Compare consistency score
# %%
# Overall consistency: cosine-based velocity confidence with no scope_key.
vkey, method = "velocity", "cosine"
velocity_confidence(deepvelo_adata, vkey=vkey, method=method)
# Copy the result column under its own name, because the scoped call in the
# next cell is read back from the same f"{vkey}_confidence_{method}" column.
deepvelo_adata.obs["overall_consistency"] = (
    deepvelo_adata.obs[f"{vkey}_confidence_{method}"].copy()
)
# %%
# Names are re-bound here so this cell does not depend on the previous one.
vkey, method, scope_key = "velocity", "cosine", "clusters"
# cosine similarity, compute within Celltype
velocity_confidence(deepvelo_adata, vkey=vkey, method=method, scope_key=scope_key)
deepvelo_adata.obs["celltype_consistency"] = (
    deepvelo_adata.obs[f"{vkey}_confidence_{method}"].copy()
)
# NOTE: example of logging metrics
# Log the mean of each per-cell consistency column to wandb, keyed by the
# column name.
wandb.log(
    {
        metric: deepvelo_adata.obs[metric].mean()
        for metric in ("celltype_consistency", "overall_consistency")
    }
)