Skip to content

Commit

Permalink
- Change STC calculation from np.matrix to np.array
Browse files Browse the repository at this point in the history
, use np.dot
- Change plotLNP to match PEP8
- Decrease time step to make calculations faster
  • Loading branch information
ycanerol committed May 17, 2017
1 parent 17a24b9 commit ca815ea
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 42 deletions.
40 changes: 22 additions & 18 deletions LNP_model
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ Created on Tue May 9 18:11:51 2017
@author: ycan
"""
#%%
# %%
import numpy as np
from scipy.stats.mstats import mquantiles

total_frames = 10000
dt = 0.001 # Time step
total_frames = 100000
dt = 0.01 # Time step
t = np.arange(0, total_frames*dt, dt) # Time vector
filter_time = .6 # The longest feature RGCs respond to is ~600ms
filter_length = int(filter_time/dt) # Filter is filter_length frames long
Expand All @@ -19,10 +19,11 @@ cweight = .5 # The weight of combination for the two filters


def make_noise(): # Generate gaussian noise for stimulus
return np.random.normal(0, 9, total_frames)
stimulus = make_noise()
return np.random.normal(0, 1, total_frames)

filter_index1 = 2 # Change filter type here
#stimulus = make_noise()

filter_index1 = 1 # Change filter type here
filter_index2 = 3


Expand All @@ -47,9 +48,9 @@ filtered1 = np.convolve(filter_kernel1, stimulus,
filtered2 = np.convolve(filter_kernel2, stimulus,
mode='full')[:-filter_length+1]

k = np.linspace(-30, 30, 1001)
nlt_index1 = 1
nlt_index2 = 5
k = np.linspace(-5, 5, 1001)
nlt_index1 = 6
nlt_index2 = 3


def nlt(k, nlt_index):
Expand All @@ -60,9 +61,11 @@ def nlt(k, nlt_index):
elif nlt_index == 3:
nlfunction = lambda x: -10/(1+np.exp(x-2))+10
elif nlt_index == 4:
nlfunction = lambda x: (-x*.4 if x < 0 else x)
nlfunction = lambda x: (-x*.8 if x < 0 else x)
elif nlt_index == 5:
nlfunction = lambda x: (0.02*x**2 if x < 0 else .04*x**2)
nlfunction = lambda x: (0.2*x**2 if x < 0 else .4*x**2)
elif nlt_index == 6:
nlfunction = lambda x: 3*x**2
else:
raise ValueError('Invalid non-linearity index')
return np.array([nlfunction(x) for x in k])
Expand All @@ -72,21 +75,22 @@ fire_rates2 = np.array(nlt(filtered2, nlt_index2))
fire_rates_sum = cweight*fire_rates1+(1-cweight)*fire_rates2
# Fire rates are combined with a linear weight

# Spikes
# %%Spikes
spikes = np.array([])
for i in fire_rates_sum:
spikes = np.append(spikes, np.random.poisson(i))
# Number of spikes for each frame

print('{} spikes generated.'.format(int(sum(spikes))))
#%%
# %%


def sta(spikes, stimulus, filter_length):
snippets = np.zeros(filter_length)
for i in range(filter_length, total_frames):
snippets = snippets+stimulus[i:i-filter_length:-1]*spikes[i]
# Snippets are inverted before being added
if spikes[i] != 0:
snippets = snippets+stimulus[i:i-filter_length:-1]*spikes[i]
# Snippets are inverted before being added
sta_unscaled = snippets/sum(spikes) # Normalize/scale the STA
sta_scaled = sta_unscaled/np.sqrt(sum(np.power(sta_unscaled, 2)))
return sta_scaled
Expand All @@ -95,7 +99,7 @@ recovered_kernel = sta(spikes, stimulus, filter_length)
filtered_recovery = np.convolve(recovered_kernel, stimulus,
mode='full')[:-filter_length+1]

#%% Variable bin size, log
# %% Variable bin size, log
log_bin_nr = 60
logbins = np.logspace(0, np.log(30)/np.log(10), log_bin_nr)
logbins = -logbins[::-1]+logbins
Expand All @@ -106,8 +110,8 @@ for i in range(log_bin_nr):
(np.average(spikes[np.where
(logbindices == i)])))

#%% Using mquantiles
bin_nr = 1000
# %% Using mquantiles
bin_nr = 100
quantiles = mquantiles(filtered_recovery,
np.linspace(0, 1, bin_nr, endpoint=False))
bindices = np.digitize(filtered_recovery, quantiles)
Expand Down
36 changes: 20 additions & 16 deletions plotLNP.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,28 @@
#plt.style.use('dark_background')
plt.style.use('default')
matplotlib.rcParams['grid.alpha'] = 0.1
rows=2
columns=2
fig=plt.figure(figsize=(12,8))
rows = 2
columns = 2
fig = plt.figure(figsize=(12, 8))

plt.subplot(rows,columns,1)
plt.plot(filter_kernel1,alpha=.2)
plt.subplot(rows, columns, 1)
plt.plot(filter_kernel1, alpha=.2)
#plt.title()

plt.subplot(rows,columns,1)
plt.plot(filter_kernel2,alpha=.2)
plt.subplot(rows, columns, 1)
plt.plot(filter_kernel2, alpha=.2)

plt.subplot(rows,columns,1)
plt.plot(cweight*filter_kernel1+(1-cweight)*filter_kernel2,alpha=.6)
plt.subplot(rows, columns, 1)
plt.plot(cweight*filter_kernel1+(1-cweight)*filter_kernel2, alpha=.6)

plt.subplot(rows,columns,1)
plt.plot(recovered_kernel,alpha=.6)
plt.subplot(rows, columns, 1)
plt.plot(recovered_kernel, alpha=.6)

plt.legend(['Filter 1','Filter 2','{}*Filter 1+{}*Filter 2'.format(cweight
,np.round(1-cweight,2)),'Spike triggered average (STA)'],
plt.legend(['Filter {}'.format(filter_index1),
'Filter {}'.format(filter_index2),
'{}*Filter {}+{}*Filter {}'.format(cweight, filter_index1,
np.round(1-cweight,2),filter_index2),
'Spike triggered average (STA)'],
fontsize='x-small')
plt.grid()
plt.title('Linear transformation')
Expand All @@ -48,9 +51,10 @@

plt.subplot(rows,columns,2)
plt.scatter(quantiles,spikecount_in_bins,s=6,alpha=.6)
plt.legend(['Non-linear transformation 1',
'Non-linear transformation 2',
'{}*NLT1+{}*NLT2'.format(cweight,np.round(1-cweight,2)),
plt.legend(['Non-linear transformation {}'.format(nlt_index1),
'Non-linear transformation {}'.format(nlt_index2),
'{}*NLT{}+{}*NLT{}'.format(cweight,nlt_index1,
np.round(1-cweight,2),nlt_index2),
'Recovered using logbins',
'Recovered using quantiles'],
fontsize='x-small')
Expand Down
34 changes: 26 additions & 8 deletions stc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,48 @@
"""
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
execution_timer = datetime.now()

sta_temp = sta(spikes, stimulus, filter_length)


def stc(spikes, stimulus, filter_length, sta_temp):
covariance = np.matrix(np.zeros((filter_length, filter_length)))
covariance = np.zeros((filter_length, filter_length))
for i in range(filter_length, total_frames):
if spikes[i] != 0:
snippet = stimulus[i:i-filter_length:-1]
# Snippets are inverted before being added
snpta = np.matrix(snippet-sta_temp)
covariance = covariance+(snpta.T*snpta)*spikes[i]
return covariance/(sum(spikes)*filter_length-1)
snpta = np.array(snippet-sta_temp)[np.newaxis,:]
covariance = covariance+np.dot(snpta.T, snpta)*spikes[i]
return covariance/(sum(spikes)-1)

recovered_stc = stc(spikes, stimulus, filter_length,
sta(spikes, stimulus, filter_length))
runtime = str(datetime.now()-execution_timer).split('.')[0]
print('Duration: {}'.format(runtime))


w,v = np.linalg.eig(recovered_stc)
# %%
w, v = np.linalg.eig(recovered_stc)
# column v[:,i] is the eigenvector corresponding to the eigenvalue w[i]
sorted_eig = np.argsort(w)[::-1]
w = w[sorted_eig]
v = v[:, sorted_eig]

fig=plt.figure(figsize=(12, 4))

plt.subplot(1,2,1)
plt.plot(w, 'o', markersize=2)

plt.subplot(1,2,2)
plt.plot(v[:, 0])
plt.plot(v[:, 1])
plt.plot(recovered_kernel)
plt.legend(['0', '1', 'STA'], fontsize='x-small')


#plt.plot(v[:, -1])
#plt.plot(v[:, -2])
#plt.legend(['1', '2', '-1', '-2'], fontsize='x-small')
plt.show()
plt.plot(v[:,1],'b')
plt.plot(v[:,2],'g')

0 comments on commit ca815ea

Please sign in to comment.