-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathspike_train_distances.py
251 lines (199 loc) · 7.45 KB
/
spike_train_distances.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
#!/usr/bin/env python
# -*- coding:utf-8 -*-
###
# Created Date: 2023-08-03 17:38:57
# Author: Gehua Ma
# -----
# Last Modified: 2024-03-08 14:22:01
# Modified By: Gehua Ma
# -----
###
import numpy as np
import torch
import math
from numba import jit, njit
from tqdm import tqdm
def onehot2firingtime(spiketrain, dt=100 / 3 / 1000):
    """Convert a binned (one-hot) spike train into an array of firing times.

    Bin ``i`` is mapped to time ``i * dt``, and each nonzero bin yields
    exactly one firing time.

    :param spiketrain: 1D torch tensor (or array-like) with one entry per
        time bin; nonzero entries mark spikes.
    :param dt: bin width. The default, 100/3/1000 (about 0.0333), is the value
        this function always used. It presumably means seconds per bin for
        30 Hz binning (TODO confirm against callers).
    :returns: 1D numpy array of firing times, in bin order.
    """
    if isinstance(spiketrain, torch.Tensor):
        # detach/cpu so grad-tracking or CUDA tensors can be converted too.
        values = spiketrain.detach().cpu().numpy()
    else:
        values = np.asarray(spiketrain)
    times = np.arange(values.shape[0]) * dt
    # Select by the spike mask rather than by the time value: a spike in
    # bin 0 has time 0 and would otherwise be indistinguishable from "no spike".
    return times[values != 0]
def victor_purpura_dist(tli, tlj, cost=1):
    r"""
    Compute the Victor-Purpura "spike time" distance between two spike trains.

    ``d = victor_purpura_dist(tli, tlj, cost)`` is the minimum total cost of
    turning one train into the other, where:

    * inserting or deleting a spike costs 1;
    * moving a spike by ``dt`` costs ``cost * |dt|``.

    ``cost`` is the metric's single free parameter, as defined in [DA2003]_.

    :param tli: sequence of spike times for the first spike train
    :param tlj: sequence of spike times for the second spike train
    :keyword cost: cost per unit time to move a spike. ``0`` returns the
        spike-count difference, and ``np.inf`` returns
        ``len(tli) + len(tlj)`` (no spike may be moved).
    :returns: spike distance metric

    Translated to Python by Nicolas Jimenez from Matlab code by Daniel Reich.

    .. [DA2003] Aronov, Dmitriy. "Fast algorithm for the metric-space analysis
                of simultaneous responses of multiple single neurons." Journal
                of Neuroscience Methods 124.2 (2003): 175-179.

    Here, the distance is 1 because there is one extra spike to be deleted at
    the end of the first spike train:

    >>> float(victor_purpura_dist([1, 2, 3, 4], [1, 2, 3], cost=1))
    1.0

    Here the distance is 0.4 because we shift the first spike by 0.2, leave
    the second alone, and shift the third one by 0.2:

    >>> round(float(victor_purpura_dist([1.2, 2, 3.2], [1, 2, 3], cost=1)), 6)
    0.4

    Here the fourth spike is shifted by 0.5, but since the cost per unit time
    is 0.5, the distance comes out to 0.25:

    >>> float(victor_purpura_dist([1, 2, 3, 4], [1, 2, 3, 3.5], cost=0.5))
    0.25
    """
    nspi = len(tli)
    nspj = len(tlj)
    if cost == 0:
        return abs(nspi - nspj)
    # np.inf (lower case): the np.Inf alias was removed in NumPy 2.0.
    if cost == np.inf:
        return nspi + nspj

    # scr[i, j] = distance between the first i spikes of tli and the first
    # j spikes of tlj. The margins hold the cost of adding/deleting every spike.
    scr = np.zeros((nspi + 1, nspj + 1))
    scr[:, 0] = np.arange(nspi + 1)
    scr[0, :] = np.arange(nspj + 1)
    # The loops are empty when either train is empty, so no guard is needed.
    for i in range(1, nspi + 1):
        for j in range(1, nspj + 1):
            scr[i, j] = min(
                scr[i - 1, j] + 1,  # delete spike i of tli
                scr[i, j - 1] + 1,  # insert spike j of tlj
                scr[i - 1, j - 1] + cost * abs(tli[i - 1] - tlj[j - 1]),  # shift
            )
    return scr[nspi, nspj]
class ExpDecay():
    """
    Exponentially decaying trace with additive spike increments.

    Each :meth:`update` multiplies the stored value by ``exp(-dt * k)`` and
    each :meth:`spike` adds 1 to it. Stepping this once per time bin yields
    an exponentially filtered spike train, e.g. for the Van Rossum distance.
    """

    def __init__(self, k=None, dt=0.0001):
        """
        :param k: decay rate (inverse time constant). Required; the ``None``
            default only exists for signature compatibility.
        :param dt: time step applied by each :meth:`update` call.
        :raises TypeError: if ``k`` is not provided.
        """
        if k is None:
            raise TypeError("ExpDecay requires a decay rate k (e.g. 1.0 / tc)")
        self.sic_val = 0.0
        self.dt = dt
        self.k = k
        # Per-step multiplicative decay, computed once.
        self.decay_factor = np.exp(-dt * k)

    def update(self, V=0):
        """Advance one time step and return the decayed value.

        ``V`` is accepted for interface compatibility but is not used.
        """
        self.sic_val = self.sic_val * self.decay_factor
        return self.sic_val

    def spike(self):
        """Add 1 to the trace and return the new value."""
        self.sic_val += 1
        return self.sic_val

    def reset(self):
        """Clear the trace back to its initial value (0.0)."""
        self.sic_val = 0.0
def van_rossum_dist(st_0, st_1, tc=10, bin_width=0.001, t_extra=1):
    """
    Calculate the Van Rossum distance between two spike trains.

    Each train is filtered with a causal exponential kernel ``exp(-t / tc)``
    on a grid of width ``bin_width`` that starts at ``t = 0``. The distance
    is then

        D = sqrt( (1 / tc) * integral (f_0(t) - f_1(t))**2 dt ),

    as defined in [VR2001]_, with the integral approximated by a Riemann sum
    over the grid. The default parameters are optimized for spike times in
    seconds.

    :param st_0: array-like of spike times for the first spike train
    :param st_1: array-like of spike times for the second spike train
    :param tc: time constant of the exponential kernel
    :param bin_width: precision in units of time to compute integral
    :param t_extra: how far past the latest spike to keep integrating.
        This is needed because the integral cannot in practice be
        evaluated between :math:`t=0` and :math:`t=inf`.
    :returns: the distance (0.0 when both trains are empty)

    A spike is registered in the first grid bin whose time is strictly
    greater than the spike time; several spikes may share one bin.

    .. [VR2001] van Rossum, Mark CW. "A novel spike distance."
                Neural Computation 13.4 (2001): 751-763.
    """
    s_0 = np.asarray(st_0, dtype=float).ravel()
    s_1 = np.asarray(st_1, dtype=float).ravel()
    if s_0.size == 0 and s_1.size == 0:
        return 0.0

    # Integrate t_extra past the latest spike so the kernel tails are included.
    t_max = max(s.max() for s in (s_0, s_1) if s.size) + t_extra
    t_range = np.arange(0, t_max, bin_width)
    decay = np.exp(-bin_width * (1.0 / tc))

    def _filtered(spikes):
        # Index of the first grid time > spike; spikes past the grid are dropped.
        idx = np.searchsorted(t_range, spikes, side="right")
        counts = np.bincount(idx, minlength=t_range.size + 1)[: t_range.size]
        f = np.empty(t_range.size)
        val = 0.0
        for i, c in enumerate(counts):
            # Decay by one step, then add every spike that landed in this bin.
            val = val * decay + c
            f[i] = val
        return f

    diff = _filtered(s_0) - _filtered(s_1)
    # Squared difference, per the VR2001 definition (not an L1 norm).
    return np.sqrt((bin_width / tc) * np.sum(diff ** 2))
def find_corner_spikes(t, train, ibegin, ti, te):
    r"""
    Return the pair of spike times in ``train`` that brackets time ``t``.

    The scan starts at ``train[ibegin]`` and finds the first spike ``t2`` with
    ``t2 >= t``. ``t1`` is the element just before it: ``train[ibegin - 1]``
    when the match is at ``ibegin > 0``, or ``ti`` when it is at index 0.
    If no spike in ``train[ibegin:]`` satisfies ``t2 >= t``, ``te`` is used
    as ``t2``.

    :param t: query time.
    :param train: spike times; the scan assumes they are sorted ascending.
    :param ibegin: index to start scanning from, so callers with increasing
        ``t`` can resume where the previous call stopped.
    :param ti: value used as the previous spike when the match is ``train[0]``
        (or when ``train`` is empty).
    :param te: value used as the following spike when none is ``>= t``.
    :returns: ``(np.array([t1, t2]), index)``, where ``index`` is the position
        of ``t2`` in ``train``, or the last valid index of ``train`` (0 if it
        is empty) when ``te`` was used. It is suitable as the next ``ibegin``.
    """
    tprev = ti if ibegin == 0 else train[ibegin - 1]
    for idx in range(ibegin, len(train)):
        ts = train[idx]
        if ts >= t:
            return np.array([tprev, ts]), idx
        tprev = ts
    # No spike at or after t: bracket with the last spike seen and te.
    # This is computed without the loop variable, so an empty train[ibegin:]
    # is handled.
    return np.array([tprev, te]), max(len(train) - 1, 0)
def bivariate_spike_distance(t1, t2, ti, te, N):
    r"""
    This Python code (including all further comments) was written by Jeremy Fix (see http://jeremy.fix.free.fr/),
    based on Matlab code written by Thomas Kreuz.

    The SPIKE-distance is described in this paper:

    .. [KT2013] Kreuz T, Chicharro D, Houghton C, Andrzejak RG, Mormann F:
                Monitoring spike train synchrony.
                J Neurophysiol 109, 1457-1472 (2013).

    Computes the time-resolved bivariate SPIKE distance of [KT2013]_.

    :param t1: 1D array-like with the spike times of the first neuron.
    :param t2: 1D array-like with the spike times of the second neuron.
    :param ti: beginning of time interval.
    :param te: end of time interval.
    :param N: number of samples.
    :returns: ``(t, d)``. ``t`` holds the ``N`` sample times evenly spaced
        in ``(ti, te]``, and ``d`` holds the distance at each sample.

    .. note::
        The arrays t1, t2 and values ti, te are unitless. Spike times are
        converted to float before padding, so integer inputs do not truncate
        fractional ``ti``/``te``.
    """
    t = np.linspace(ti + (te - ti) / N, te, N)

    # Pad each train with the interval edges. Convert to float first, because
    # np.insert/np.append keep the input dtype and would truncate ti/te
    # for integer spike times.
    t1 = np.concatenate(([ti], np.asarray(t1, dtype=float).ravel(), [te]))
    t2 = np.concatenate(([ti], np.asarray(t2, dtype=float).ravel(), [te]))

    # Columns: 0 sample time, 1-2 previous/following spike of t1,
    # 3-4 previous/following spike of t2.
    corner_spikes = np.zeros((N, 5))
    corner_spikes[:, 0] = t
    ibegin_t1 = 0
    ibegin_t2 = 0
    for itc, tc in enumerate(t):
        corner_spikes[itc, 1:3], ibegin_t1 = find_corner_spikes(tc, t1, ibegin_t1, ti, te)
        corner_spikes[itc, 3:5], ibegin_t2 = find_corner_spikes(tc, t2, ibegin_t2, ti, te)

    tp1 = corner_spikes[:, 1]
    tf1 = corner_spikes[:, 2]
    tp2 = corner_spikes[:, 3]
    tf2 = corner_spikes[:, 4]

    # Inter-spike interval containing each sample, for each train.
    isi1 = tf1 - tp1
    isi2 = tf2 - tp2

    # Distance from each corner spike to the nearest spike of the other train
    # (broadcast as an (N, len(train)) matrix, then reduced over the train).
    dp1 = np.min(np.fabs(t2[np.newaxis, :] - tp1[:, np.newaxis]), axis=1)
    df1 = np.min(np.fabs(t2[np.newaxis, :] - tf1[:, np.newaxis]), axis=1)
    dp2 = np.min(np.fabs(t1[np.newaxis, :] - tp2[:, np.newaxis]), axis=1)
    df2 = np.min(np.fabs(t1[np.newaxis, :] - tf2[:, np.newaxis]), axis=1)

    # Distance from each sample time to its previous/following spike.
    xp1 = t - tp1
    xf1 = tf1 - t
    xp2 = t - tp2
    xf2 = tf2 - t

    # Local spike-time differences, linearly interpolated within each ISI.
    S1 = (dp1 * xf1 + df1 * xp1) / isi1
    S2 = (dp2 * xf2 + df2 * xp2) / isi2

    # Cross-weight by the other train's ISI and normalise by
    # 2 * mean_isi**2 = (isi1 + isi2)**2 / 2.
    d = (S1 * isi2 + S2 * isi1) / ((isi1 + isi2) ** 2.0 / 2.0)
    return t, d