From ca815ea2779f9b01afe819227be806536ac2f6fe Mon Sep 17 00:00:00 2001 From: ycanerol Date: Wed, 17 May 2017 16:51:55 +0200 Subject: [PATCH] - Change STC calculation from np.matrix to np.array , use np.dot - Change plotLNP to match PEP8 - Decrease time step to make calculations faster --- LNP_model | 40 ++++++++++++++++++++++------------------ plotLNP.py | 36 ++++++++++++++++++++---------------- stc.py | 34 ++++++++++++++++++++++++++-------- 3 files changed, 68 insertions(+), 42 deletions(-) diff --git a/LNP_model b/LNP_model index bd2aa6d..d9a721d 100644 --- a/LNP_model +++ b/LNP_model @@ -5,12 +5,12 @@ Created on Tue May 9 18:11:51 2017 @author: ycan """ -#%% +# %% import numpy as np from scipy.stats.mstats import mquantiles -total_frames = 10000 -dt = 0.001 # Time step +total_frames = 100000 +dt = 0.01 # Time step t = np.arange(0, total_frames*dt, dt) # Time vector filter_time = .6 # The longest feature RGCs respond to is ~600ms filter_length = int(filter_time/dt) # Filter is filter_length frames long @@ -19,10 +19,11 @@ cweight = .5 # The weight of combination for the two filters def make_noise(): # Generate gaussian noise for stimulus - return np.random.normal(0, 9, total_frames) -stimulus = make_noise() + return np.random.normal(0, 1, total_frames) -filter_index1 = 2 # Change filter type here +#stimulus = make_noise() + +filter_index1 = 1 # Change filter type here filter_index2 = 3 @@ -47,9 +48,9 @@ filtered1 = np.convolve(filter_kernel1, stimulus, filtered2 = np.convolve(filter_kernel2, stimulus, mode='full')[:-filter_length+1] -k = np.linspace(-30, 30, 1001) -nlt_index1 = 1 -nlt_index2 = 5 +k = np.linspace(-5, 5, 1001) +nlt_index1 = 6 +nlt_index2 = 3 def nlt(k, nlt_index): @@ -60,9 +61,11 @@ def nlt(k, nlt_index): elif nlt_index == 3: nlfunction = lambda x: -10/(1+np.exp(x-2))+10 elif nlt_index == 4: - nlfunction = lambda x: (-x*.4 if x < 0 else x) + nlfunction = lambda x: (-x*.8 if x < 0 else x) elif nlt_index == 5: - nlfunction = lambda x: (0.02*x**2 if x < 0 else .04*x**2) + nlfunction = lambda x: (0.2*x**2 if x < 0 else .4*x**2) + elif nlt_index == 6: + nlfunction = lambda x: 3*x**2 else: raise ValueError('Invalid non-linearity index') return np.array([nlfunction(x) for x in k]) @@ -72,21 +75,22 @@ fire_rates2 = np.array(nlt(filtered2, nlt_index2)) fire_rates_sum = cweight*fire_rates1+(1-cweight)*fire_rates2 # Fire rates are combined with a linear weight -# Spikes +# %%Spikes spikes = np.array([]) for i in fire_rates_sum: spikes = np.append(spikes, np.random.poisson(i)) # Number of spikes for each frame print('{} spikes generated.'.format(int(sum(spikes)))) -#%% +# %% def sta(spikes, stimulus, filter_length): snippets = np.zeros(filter_length) for i in range(filter_length, total_frames): - snippets = snippets+stimulus[i:i-filter_length:-1]*spikes[i] - # Snippets are inverted before being added + if spikes[i] != 0: + snippets = snippets+stimulus[i:i-filter_length:-1]*spikes[i] + # Snippets are inverted before being added sta_unscaled = snippets/sum(spikes) # Normalize/scale the STA sta_scaled = sta_unscaled/np.sqrt(sum(np.power(sta_unscaled, 2))) return sta_scaled @@ -95,7 +99,7 @@ recovered_kernel = sta(spikes, stimulus, filter_length) filtered_recovery = np.convolve(recovered_kernel, stimulus, mode='full')[:-filter_length+1] -#%% Variable bin size, log +# %% Variable bin size, log log_bin_nr = 60 logbins = np.logspace(0, np.log(30)/np.log(10), log_bin_nr) logbins = -logbins[::-1]+logbins @@ -106,8 +110,8 @@ for i in range(log_bin_nr): (np.average(spikes[np.where (logbindices 
== i)]))) -#%% Using mquantiles -bin_nr = 1000 +# %% Using mquantiles +bin_nr = 100 quantiles = mquantiles(filtered_recovery, np.linspace(0, 1, bin_nr, endpoint=False)) bindices = np.digitize(filtered_recovery, quantiles) diff --git a/plotLNP.py b/plotLNP.py index e0e0622..bfac7cc 100644 --- a/plotLNP.py +++ b/plotLNP.py @@ -10,25 +10,28 @@ #plt.style.use('dark_background') plt.style.use('default') matplotlib.rcParams['grid.alpha'] = 0.1 -rows=2 -columns=2 -fig=plt.figure(figsize=(12,8)) +rows = 2 +columns = 2 +fig = plt.figure(figsize=(12, 8)) -plt.subplot(rows,columns,1) -plt.plot(filter_kernel1,alpha=.2) +plt.subplot(rows, columns, 1) +plt.plot(filter_kernel1, alpha=.2) #plt.title() -plt.subplot(rows,columns,1) -plt.plot(filter_kernel2,alpha=.2) +plt.subplot(rows, columns, 1) +plt.plot(filter_kernel2, alpha=.2) -plt.subplot(rows,columns,1) -plt.plot(cweight*filter_kernel1+(1-cweight)*filter_kernel2,alpha=.6) +plt.subplot(rows, columns, 1) +plt.plot(cweight*filter_kernel1+(1-cweight)*filter_kernel2, alpha=.6) -plt.subplot(rows,columns,1) -plt.plot(recovered_kernel,alpha=.6) +plt.subplot(rows, columns, 1) +plt.plot(recovered_kernel, alpha=.6) -plt.legend(['Filter 1','Filter 2','{}*Filter 1+{}*Filter 2'.format(cweight - ,np.round(1-cweight,2)),'Spike triggered average (STA)'], +plt.legend(['Filter {}'.format(filter_index1), + 'Filter {}'.format(filter_index2), + '{}*Filter {}+{}*Filter {}'.format(cweight, filter_index1, + np.round(1-cweight,2),filter_index2), + 'Spike triggered average (STA)'], fontsize='x-small') plt.grid() plt.title('Linear transformation') @@ -48,9 +51,10 @@ plt.subplot(rows,columns,2) plt.scatter(quantiles,spikecount_in_bins,s=6,alpha=.6) -plt.legend(['Non-linear transformation 1', - 'Non-linear transformation 2', - '{}*NLT1+{}*NLT2'.format(cweight,np.round(1-cweight,2)), +plt.legend(['Non-linear transformation {}'.format(nlt_index1), + 'Non-linear transformation {}'.format(nlt_index2), + '{}*NLT{}+{}*NLT{}'.format(cweight,nlt_index1, + np.round(1-cweight,2),nlt_index2), 'Recovered using logbins', 'Recovered using quantiles'], fontsize='x-small') diff --git a/stc.py b/stc.py index 91b13ed..6a16e71 100644 --- a/stc.py +++ b/stc.py @@ -9,30 +9,48 @@ """ from datetime import datetime import numpy as np +import matplotlib.pyplot as plt execution_timer = datetime.now() sta_temp = sta(spikes, stimulus, filter_length) def stc(spikes, stimulus, filter_length, sta_temp): - covariance = np.matrix(np.zeros((filter_length, filter_length))) + covariance = np.zeros((filter_length, filter_length)) for i in range(filter_length, total_frames): if spikes[i] != 0: snippet = stimulus[i:i-filter_length:-1] # Snippets are inverted before being added - snpta = np.matrix(snippet-sta_temp) - covariance = covariance+(snpta.T*snpta)*spikes[i] - return covariance/(sum(spikes)*filter_length-1) + snpta = np.array(snippet-sta_temp)[np.newaxis,:] + covariance = covariance+np.dot(snpta.T, snpta)*spikes[i] + return covariance/(sum(spikes)-1) recovered_stc = stc(spikes, stimulus, filter_length, sta(spikes, stimulus, filter_length)) runtime = str(datetime.now()-execution_timer).split('.')[0] print('Duration: {}'.format(runtime)) - -w,v = np.linalg.eig(recovered_stc) +# %% +w, v = np.linalg.eig(recovered_stc) # column v[:,i] is the eigenvector corresponding to the eigenvalue w[i] +sorted_eig = np.argsort(w)[::-1] +w = w[sorted_eig] +v = v[:, sorted_eig] + +fig=plt.figure(figsize=(12, 4)) + +plt.subplot(1,2,1) plt.plot(w, 'o', markersize=2) + +plt.subplot(1,2,2) +plt.plot(v[:, 0]) +plt.plot(v[:, 1]) 
+plt.plot(recovered_kernel) +plt.legend(['0', '1', 'STA'], fontsize='x-small') + + +#plt.plot(v[:, -1]) +#plt.plot(v[:, -2]) +#plt.legend(['1', '2', '-1', '-2'], fontsize='x-small') plt.show() -plt.plot(v[:,1],'b') -plt.plot(v[:,2],'g') +
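
For reference, a minimal self-contained sketch of the ndarray/np.dot spike-triggered
covariance pattern this patch switches to. All names here (n_frames, toy_filter, the
rectified rate) are toy stand-ins rather than the variables defined in LNP_model, so
the numbers are illustrative only:

# Sketch of the ndarray/np.dot STC pattern used in stc.py above.
# n_frames, toy_filter, etc. are toy stand-ins, not LNP_model variables.
import numpy as np

np.random.seed(0)
n_frames = 5000
filter_length = 20

stimulus = np.random.normal(0, 1, n_frames)
toy_filter = np.exp(-np.arange(filter_length)/5.0)   # arbitrary kernel
generator = np.convolve(toy_filter, stimulus, mode='full')[:n_frames]
rate = np.maximum(generator, 0)                       # toy rectifying nonlinearity
spikes = np.random.poisson(rate)

# Spike-triggered average
sta = np.zeros(filter_length)
for i in range(filter_length, n_frames):
    if spikes[i] != 0:
        sta += stimulus[i:i-filter_length:-1]*spikes[i]   # snippet is time-reversed
sta /= spikes.sum()

# Spike-triggered covariance with plain ndarrays and np.dot
# (outer product of the mean-subtracted snippet with itself)
stc = np.zeros((filter_length, filter_length))
for i in range(filter_length, n_frames):
    if spikes[i] != 0:
        snippet = stimulus[i:i-filter_length:-1] - sta
        snpta = snippet[np.newaxis, :]
        stc += np.dot(snpta.T, snpta)*spikes[i]
stc /= spikes.sum() - 1

w, v = np.linalg.eig(stc)              # column v[:, i] pairs with eigenvalue w[i]
order = np.argsort(w)[::-1]            # sort eigenvalues, descending
w, v = w[order], v[:, order]
print(w[:3])

The outer product np.dot(snpta.T, snpta) of the mean-subtracted snippet is what
replaces the old np.matrix multiplication, and dividing by sum(spikes)-1 at the end
mirrors the normalization change made in stc.py.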
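
Since the recovered covariance matrix is symmetric, one possible alternative to the
eig + argsort step added above (not what the patch does) is np.linalg.eigh, which
guarantees real eigenvalues and returns them in ascending order, so the descending
sort reduces to a reversal:

# Hypothetical alternative to the eig/argsort step in stc.py: eigh exploits the
# symmetry of the covariance matrix and returns real eigenvalues in ascending
# order, so reversing gives the descending order used for the eigenvalue plot.
import numpy as np

def sorted_eig_sym(cov):
    w, v = np.linalg.eigh(cov)      # ascending order, real for symmetric input
    return w[::-1], v[:, ::-1]      # descending eigenvalues with matching eigenvectors

# usage with the patch's variable: w, v = sorted_eig_sym(recovered_stc)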
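
On a related note, the frame-by-frame Poisson loop shown as context in the LNP_model
hunk could in principle be collapsed into a single vectorized draw, because
np.random.poisson accepts an array of rates. A sketch with a stand-in rate vector
(the real one would be fire_rates_sum from LNP_model):

# Vectorized alternative to the spike-generation loop in LNP_model:
# np.random.poisson broadcasts over an array of rates, one draw per frame.
import numpy as np

fire_rates_sum = np.abs(np.random.normal(1.0, 0.5, 1000))  # stand-in rate vector
spikes = np.random.poisson(fire_rates_sum).astype(float)   # float, like the original spikes array
print('{} spikes generated.'.format(int(spikes.sum())))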