6_ISC_test.py
#!/usr/bin/env python3
# 1) Compute many leave-one-out values (hundreds?), by randomly selecting subjects and correlating against the N-1 other subjects.
# 2) Compute many split-half values, by randomly selecting an N/2 split of subjects and correlating the mean timecourses of the halves.
# 3) Compute many pairwise values, by randomly selecting two subjects and correlating their timecourses.
# In all cases, time how long it takes to collect these values.
# Then compute f using subsets of each of these values.
import time
import h5py
import tqdm
import numpy as np
import deepdish as dd
from scipy.stats import zscore
from scipy.spatial.distance import squareform
from random import shuffle
import matplotlib.pyplot as plt
from settings import *
# Only using 233 subj
subord = dd.io.load(metaphenopath+'pheno_2019-05-28.h5',['/subs'])[0]
ISCf = ISCpath+'old_ISC/ISC_2019-05-28.h5'
n_subj = len(subord)
n_vox = 5
ISCversions = ['Loo','SH','Pair']
# Some math to convert between correlation values:
# f acts like the ratio of shared signal variance to subject-specific noise
# variance in a signal-plus-noise model; each ISC estimator maps to f differently.
def corr_convert(r,N,corrtype='SH'):
    # Invert the estimator-specific r(f) relation to recover f...
    if corrtype == 'SH':
        f = 2*r/(N*(1-r))
    elif corrtype == 'Pair':
        f = r/(1-r)
    else: # 'Loo'
        f = (N*np.square(r)+np.sqrt((N**2)*np.power(r,4,dtype=np.float16)+4*np.square(r)*(N-1)*(1-np.square(r))))/(2*(N-1)*(1-np.square(r)))
    # ...then map f forward to the r each estimator would produce
    r_pw = f/(f+1)
    r_sh = (N*f)/(N*f+2)
    r_loo = (np.sqrt(N-1)*f)/(np.sqrt(f+1)*np.sqrt((N-1)*f+1))
    return r_pw,r_sh,r_loo
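# Worked example (arbitrary numbers, not from the data): with N = 233 and a
# split-half r of 0.5, f = 2*0.5/(233*(1-0.5)) ~= 0.0086; mapping back gives
# r_sh = 0.5 (round trip), r_pw ~= 0.0085, and r_loo ~= 0.075, i.e. the same
# underlying f yields very different r depending on the estimator.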
for task in ['DM','TP']:
    print(task)
    non_nan_verts = np.where(~np.isnan(np.concatenate([dd.io.load(subord[0],['/'+task+'/L'])[0], dd.io.load(subord[0],['/'+task+'/R'])[0]], axis=0))[:,0])[0]
    dictall = {k:{'ISC':None,'Time':None} for k in ISCversions}
    _,n_time = dd.io.load(subord[0],['/'+task+'/L'])[0].shape
    dictall['verts'] = non_nan_verts[np.random.choice(len(non_nan_verts),n_vox,replace=False)]
    D = np.empty((n_vox,n_time,n_subj),dtype='float16')
    with h5py.File(ISCf,'r') as hf:
        keys = list(hf[task]['data'].keys())
    # Data are stored in 250-TR segments keyed by segment index
    for key in keys:
        D[:,0+250*int(key):250+250*int(key),:] = dd.io.load(ISCf,['/'+task+'/data/'+key], sel=dd.aslice[dictall['verts'],:,:])[0]
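    # D is now (n_vox, n_time, n_subj); subjects with missing data contribute
    # NaNs, which the running sums below exclude via nansum and per-element counts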
    # Loo
    print('Leave one out...')
    dictall['Loo']['ISC'] = np.zeros((n_vox,n_subj),dtype='float16')
    dictall['Loo']['Time'] = []
    # Loop across choice of leave-one-out subject
    for loo_subj in tqdm.tqdm(range(n_subj)):
        t = time.process_time()
        group = np.zeros((n_vox,n_time),dtype='float16')
        # Per-element count of non-NaN subjects in the N-1 group
        groupn = np.ones((n_vox,n_time),dtype='int')*(n_subj-1)
        for i in range(n_subj):
            if i != loo_subj:
                group = np.nansum(np.stack((group,D[:,:,i])),axis=0)
                nanverts = np.argwhere(np.isnan(D[:,:,i]))
                groupn[nanverts[:,0],nanverts[:,1]] = groupn[nanverts[:,0],nanverts[:,1]]-1
        group = zscore(group/groupn,axis=1)
        subj = zscore(D[:,:,loo_subj],axis=1)
        # Pearson r as the dot product of z-scored timecourses
        dictall['Loo']['ISC'][:,loo_subj] = np.sum(np.multiply(group,subj),axis=1)/(n_time-1)
        dictall['Loo']['Time'].append(time.process_time() - t)
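    # Each column of dictall['Loo']['ISC'] is one leave-one-out sample, so the
    # loop yields n_subj correlation values per voxel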
    # SH
    print('Split Halves...')
    dictall['SH']['ISC'] = np.zeros((n_vox,n_subj),dtype='float16')
    dictall['SH']['Time'] = []
    subjl = np.arange(n_subj)
    for sh in tqdm.tqdm(range(n_subj)):
        t = time.process_time()
        shuffle(subjl)
        groups = np.zeros((2,n_vox,n_time),dtype='float16')
        for h in [0,1]:
            group = np.zeros((n_vox,n_time),dtype='float16')
            # Per-element count of non-NaN subjects in this half (n_subj//2 per half)
            groupn = np.ones((n_vox,n_time),dtype='int')*(n_subj//2)
            for i in subjl[h*(n_subj//2):(h+1)*(n_subj//2)]:
                group = np.nansum(np.stack((group,D[:,:,i])),axis=0)
                nanverts = np.argwhere(np.isnan(D[:,:,i]))
                groupn[nanverts[:,0],nanverts[:,1]] = groupn[nanverts[:,0],nanverts[:,1]]-1
            groups[h] = zscore(group/groupn,axis=1)
        # Correlate the two half-group mean timecourses
        dictall['SH']['ISC'][:,sh] = np.sum(np.multiply(groups[0],groups[1]),axis=1)/(n_time-1)
        dictall['SH']['Time'].append(time.process_time() - t)
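    # dictall['SH']['ISC'] likewise holds n_subj random split-half samples per
    # voxel; with n_subj odd, each shuffle leaves one subject out of both halves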
    # Pairwise
    print('Pairwise...')
    dictall['Pair']['Time'] = []
    for i in range(n_subj):
        voxel_iscs = []
        t = time.process_time()
        for v in np.arange(n_vox):
            voxel_data = D[v,:,:i+1].T
            # Condensed correlation matrix over all pairs of subjects (upper triangle)
            iscs = squareform(np.corrcoef(voxel_data),checks=False)
            voxel_iscs.append(iscs)
        dictall['Pair']['Time'].append(time.process_time() - t)
    # Keep the final iteration: all n_subj*(n_subj-1)/2 pairs x n_vox voxels
    dictall['Pair']['ISC'] = np.column_stack(voxel_iscs)
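    # squareform() returns the condensed (vector) form of the symmetric
    # correlation matrix, so each row of dictall['Pair']['ISC'] is one subject
    # pair; the resampling below can therefore draw random pairs by row index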
    n_it = 10
    for g in ISCversions:
        i = dictall[g]
        i['TimeCum'] = []
        i['ISCCum'] = np.zeros((n_subj-1,n_vox,n_it))
        i['f'] = np.zeros((n_subj-1,n_vox,n_it))
        # Loop over subset sizes s = 2..n_subj, resampling n_it times per size
        for s in np.arange(2,n_subj+1):
            i['TimeCum'].append(np.sum(i['Time'][:s]))
            for it in range(n_it):
                if g != 'Pair':
                    # Average the Loo/SH values of s randomly chosen subjects
                    randsubjs = np.random.choice(n_subj,s,replace=False)
                    i['ISCCum'][s-2,:,it] = np.nanmean(i['ISC'][:,randsubjs],axis=1)
                else:
                    # Average as many random pairs as s subjects would generate
                    randsubjs = np.random.choice(int((n_subj*n_subj-n_subj)/2),int((s*s-s)/2),replace=False)
                    i['ISCCum'][s-2,:,it] = np.nanmean(i['ISC'][randsubjs,:],axis=0)
            # Convert mean r to f with the estimator-specific inversion (cf. corr_convert)
            N = n_subj
            r = i['ISCCum'][s-2]
            if g == 'Pair':
                i['f'][s-2] = r/(1-r)
            elif g == 'SH':
                i['f'][s-2] = 2*r/(N*(1-r))
            elif g == 'Loo':
                i['f'][s-2] = (N*np.square(r)+np.sqrt((N**2)*np.power(r,4,dtype=np.float16)+4*np.square(r)*(N-1)*(1-np.square(r))))/(2*(N-1)*(1-np.square(r)))
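    # All three methods should converge to the same f as s grows; the figures
    # below plot how quickly each gets there and at what computational cost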
    figsubj = 200
    fig = plt.figure()
    for v in range(n_vox):
        ax = fig.add_subplot(n_vox,1,v+1)
        if v == 0:
            ax.set_title(task)
        for g in ISCversions:
            i = dictall[g]
            # plot Time vs Accuracy (deviation of f from its final value):
            final_f = np.mean(i['f'][-1,v,:])
            y = np.mean(i['f'][:figsubj,v,:]-final_f,axis=1)
            error = np.max(i['f'][:figsubj,v,:]-final_f,axis=1)
            ax.plot(i['TimeCum'][:figsubj],y,label=g)
            ax.fill_between(i['TimeCum'][:figsubj], y-error, y+error,alpha=0.2)
        ax.set_ylabel('f acc\nvert =\n'+str(dictall['verts'][v]),size=7)
        #if g == 'Pair':
        #    ax.set_ylim(y[0]-error[0],y[0]+error[0])
        ax.set_xlim(min(dictall['Pair']['TimeCum'][:figsubj]),max(dictall['SH']['TimeCum'][:figsubj]))
        if v == n_vox-1:
            ax.legend(bbox_to_anchor=(0.7,1,0.5,3.5))
            ax.set_xlabel('Time [s]')
    fig.savefig(figurepath+'TimeVsFacc_'+task+'_'+str(figsubj)+'.png', bbox_inches="tight",dpi=300)
    for figsubj in [25,50,n_subj-1]:
        fig = plt.figure()
        for v in range(n_vox):
            ax0 = fig.add_subplot(n_vox+1,1,v+1)
            if v == 0:
                ax0.set_title(task)
            for g in ISCversions:
                i = dictall[g]
                # plot subj vs f for this voxel:
                ax0.plot(np.arange(1,figsubj+1),np.mean(i['f'][:figsubj,v,:],axis=1))
            # Reference line: grand-mean f across the three methods at full N
            final_f = np.mean((dictall['SH']['f'][-1,v,:]+dictall['Loo']['f'][-1,v,:]+dictall['Pair']['f'][-1,v,:])/3)
            ax0.plot([figsubj-1,figsubj],[final_f,final_f],'k-')
            ax0.set_ylabel('f\nvert =\n'+str(dictall['verts'][v]),size=7)
            ax0.set_xlim(2,figsubj)
            if v == n_vox-1:
                ax1 = fig.add_subplot(n_vox+1,1,n_vox+1)
                for g in ISCversions:
                    # plot subj vs compute time for each method
                    ax1.plot(np.arange(1,figsubj+1),dictall[g]['TimeCum'][:figsubj],label=g)
                ax1.set_xlim(2,figsubj)
                ax1.legend(bbox_to_anchor=(0.7,1,0.5,3.5))
                ax1.set_xlabel('Subjects')
                ax1.set_ylabel('Time [s]')
        fig.savefig(figurepath+'FvsCompT_'+task+'_'+str(figsubj)+'.png', bbox_inches="tight",dpi=300)
    # Save everything for this task ('a' appends; fails if the group already exists)
    with h5py.File(ISCpath+'ISC_test.h5','a') as hf:
        grp = hf.create_group(task)
        for g,i in dictall.items():
            if isinstance(i,dict):
                ds = grp.create_group(g)
                for gi,ii in i.items():
                    ds.create_dataset(gi,data=ii)
            else:
                # e.g. dictall['verts'] is a plain array, not a sub-dict
                grp.create_dataset(g,data=i)
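# To read the results back later (hypothetical snippet; paths as saved above):
# with h5py.File(ISCpath+'ISC_test.h5','r') as hf:
#     loo_isc = hf['DM']['Loo']['ISC'][:]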