-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path01.deconv.py
68 lines (65 loc) · 2.56 KB
/
01.deconv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# %% imports and definition
import os
import shutil
import dask as da
import numpy as np
import pandas as pd
import plotly.express as px
import xarray as xr
from minian.cnmf import update_temporal
IN_PATH = "./intermediate/concat"
OUT_PATH = "./intermediate/deconv"
FIG_PATH = "./figs/deconv"
MINIAN_INT = "./minian_int"
PARAM_NCELL_PLT = 10
# %% run deconvolution
if __name__ == "__main__":
shutil.rmtree(MINIAN_INT, ignore_errors=True)
os.environ["MINIAN_INTERMEDIATE"] = MINIAN_INT
os.makedirs(IN_PATH, exist_ok=True)
os.makedirs(OUT_PATH, exist_ok=True)
os.makedirs(FIG_PATH, exist_ok=True)
os.makedirs(MINIAN_INT, exist_ok=True)
sig_master = xr.open_dataarray(os.path.join(IN_PATH, "sig_master.nc")).rename(
{"master_uid": "unit_id"}
)
A_master = xr.open_dataarray(os.path.join(IN_PATH, "A_master.nc")).rename(
{"master_uid": "unit_id"}
)
C_ls = []
S_ls = []
for anm, sig_anm in sig_master.groupby("animal"):
A_anm = A_master.sel(animal=anm).chunk(
{"height": -1, "width": -1, "unit_id": 1}
)
sig_anm = sig_anm.squeeze().chunk({"frame": -1, "unit_id": 1})
with da.config.set(scheduler="processes"):
C_anm, S_anm, b0, c0, g, mask = update_temporal(
A=A_anm,
C=sig_anm,
YrA=sig_anm,
noise_freq=0.1,
jac_thres=0.8,
sparse_penal=0.5,
)
C_anm = C_anm.assign_coords(session=sig_anm.coords["session"], animal=anm)
S_anm = S_anm.assign_coords(session=sig_anm.coords["session"], animal=anm)
np.random.seed(42)
uid_plt = np.sort(np.random.choice(C_anm.coords["unit_id"], PARAM_NCELL_PLT))
C_df = C_anm.sel(unit_id=uid_plt).to_series().rename("val").reset_index()
S_df = S_anm.sel(unit_id=uid_plt).to_series().rename("val").reset_index()
sig_df = sig_anm.sel(unit_id=uid_plt).to_series().rename("val").reset_index()
C_df["var"] = "C"
S_df["var"] = "S"
sig_df["var"] = "sig"
plt_df = pd.concat([C_df, S_df, sig_df], ignore_index=True)
fig = px.line(plt_df, x="frame", y="val", color="var", facet_row="unit_id")
fig.update_yaxes(matches=None)
fig.update_layout(height=200 * PARAM_NCELL_PLT)
fig.write_html(os.path.join(FIG_PATH, "{}.html".format(anm)))
C_ls.append(C_anm)
S_ls.append(S_anm)
C = xr.concat(C_ls, "animal")
S = xr.concat(S_ls, "animal")
C.to_netcdf(os.path.join(OUT_PATH, "C.nc"))
S.to_netcdf(os.path.join(OUT_PATH, "S.nc"))