Complete STC NLT recovery, plot recovered NLTs
ycanerol committed May 19, 2017
1 parent 430835a commit 8930b11
Showing 3 changed files with 43 additions and 31 deletions.
32 changes: 15 additions & 17 deletions LNP_model
@@ -5,7 +5,6 @@ Created on Tue May 9 18:11:51 2017
@author: ycan
"""
-# %%
import numpy as np
from scipy.stats.mstats import mquantiles
from datetime import datetime
@@ -53,9 +52,9 @@ filtered2 = np.convolve(filter_kernel2, stimulus,

k = np.linspace(-5, 5, 1001)
nlt_index1 = 1
-nlt_index2 = 2
+nlt_index2 = 4


# %%
def nlt(k, nlt_index):
if nlt_index == 1:
nlfunction = lambda x: (0 if x < 0 else 4.2*x)
@@ -72,17 +71,15 @@ def nlt(k, nlt_index):
else:
raise ValueError('Invalid non-linearity index')
return np.array([nlfunction(x) for x in k])
-# %%

fire_rates1 = np.array(nlt(filtered1, nlt_index1))
fire_rates2 = np.array(nlt(filtered2, nlt_index2))
fire_rates_sum = cweight*fire_rates1+(1-cweight)*fire_rates2
# Fire rates are combined with a linear weight

# %% Spikes
spikes = np.array(np.random.poisson(fire_rates_sum*dt))

print('{} spikes generated.'.format(int(sum(spikes))))
-# %%


def sta(spikes, stimulus, filter_length):
@@ -94,12 +91,12 @@ def sta(spikes, stimulus, filter_length):
sta_unscaled = snippets/sum(spikes) # Normalize/scale the STA
sta_scaled = sta_unscaled/np.sqrt(sum(np.power(sta_unscaled, 2)))
return sta_scaled, sta_unscaled
-recovered_kernel = sta(spikes, stimulus, filter_length)[0] # Use scaled STA
+recovered_kernel = sta(spikes, stimulus, filter_length)[0]  # Use scaled STA

filtered_recovery = np.convolve(recovered_kernel, stimulus,
mode='full')[:-filter_length+1]

-# %% Variable bin size, log
+# Variable bin size, log
def log_nlt_recovery(spikes, filtered_recovery, bin_nr, k):
logbins = np.logspace(0, np.log(max(k))/np.log(10), bin_nr)
logbins = -logbins[::-1]+logbins
@@ -111,9 +108,8 @@ def log_nlt_recovery(spikes, filtered_recovery, bin_nr, k):
(logbindices == i)]))/dt)
return logbins, spikecount_in_logbins


-# %% Using mquantiles
-def q_nlt_recovery(spikes,filtered_recovery, bin_nr):
+# Using mquantiles
+def q_nlt_recovery(spikes, filtered_recovery, bin_nr, k=0):
quantiles = mquantiles(filtered_recovery,
np.linspace(0, 1, bin_nr, endpoint=False))
bindices = np.digitize(filtered_recovery, quantiles)
@@ -123,17 +119,19 @@ def q_nlt_recovery(spikes,filtered_recovery, bin_nr):
spikecount_in_bins = np.append(spikecount_in_bins,
(np.average(spikes[np.where
(bindices == i)])/dt))
-return quantiles,spikecount_in_bins
+return quantiles, spikecount_in_bins

-logbins_sta,spikecount_in_logbins_sta = log_nlt_recovery(spikes,
-                                                         filtered_recovery, 60, k)
+logbins_sta, spikecount_in_logbins_sta = log_nlt_recovery(
+                                            spikes,
+                                            filtered_recovery,
+                                            60, k)

-quantiles_sta,spikecount_in_bins_sta = q_nlt_recovery(spikes,
-                                                      filtered_recovery,100)
+quantiles_sta, spikecount_in_bins_sta = q_nlt_recovery(spikes,
+                                                       filtered_recovery, 100)

runfile('/Users/ycan/Documents/official/gottingen/lab rotations/LR3 Gollisch/scripts/stc.py', wdir='/Users/ycan/Documents/python')

runfile('/Users/ycan/Documents/official/gottingen/lab rotations/LR3 Gollisch/scripts/plotLNP.py', wdir='/Users/ycan/Documents/python')

runtime = str(datetime.now()-execution_timer).split('.')[0]
print('Duration: {}'.format(runtime))
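
A minimal, self-contained sketch of the pipeline LNP_model implements — filter white noise, rectify, draw Poisson spikes, then recover the filter by spike-triggered averaging. The toy kernel, seed, stimulus length, and rate constant below are illustrative assumptions, not values from the repository.

```python
import numpy as np

rng = np.random.default_rng(0)          # assumed seed, for reproducibility
dt = 0.001                              # time step (s)
filter_length = 40
stimulus = rng.normal(size=100_000)     # white-noise stimulus

t = np.arange(filter_length)
true_kernel = np.exp(-t / 8.0) * np.sin(t / 4.0)   # assumed toy filter
true_kernel /= np.sqrt(np.sum(true_kernel**2))     # unit norm

# Linear stage: trim so filtered has the same length as stimulus
filtered = np.convolve(true_kernel, stimulus, mode='full')[:-filter_length + 1]

# Non-linear stage (half-wave rectifier, as in nlt_index 1) + Poisson spiking
rate = np.where(filtered < 0, 0, 4.2 * filtered)
spikes = rng.poisson(rate * dt)

def sta(spikes, stimulus, filter_length):
    # Spike-triggered average: mean stimulus snippet preceding each spike
    snippets = np.zeros(filter_length)
    for i in range(filter_length, len(spikes)):
        if spikes[i]:
            snippets += spikes[i] * stimulus[i - filter_length + 1:i + 1][::-1]
    sta_unscaled = snippets / spikes.sum()
    return sta_unscaled / np.sqrt(np.sum(sta_unscaled**2))  # scaled to unit norm

recovered_kernel = sta(spikes, stimulus, filter_length)
print('kernel correlation:', np.corrcoef(true_kernel, recovered_kernel)[0, 1])
```

log_nlt_recovery and q_nlt_recovery then estimate the non-linearity from this filtered signal: bin it (logarithmically or by quantiles), average the spike counts within each bin, and divide by dt to get a firing rate per bin.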
22 changes: 14 additions & 8 deletions plotLNP.py
@@ -59,12 +59,13 @@
'Non-linear transformation {}'.format(nlt_index2),
'{}*NLT{}+{}*NLT{}'.format(cweight,nlt_index1,
np.round(1-cweight,2),nlt_index2),
-'Recovered using logbins',
-'Recovered using quantiles'],
+'Recovered using logbins (STA)',
+'Recovered using quantiles (STA)'],
fontsize='x-small')
plt.title('Non-linear transformation')
plt.xlabel('?')
plt.ylabel('?')
-plt.scatter(logbins_stc1,spikecount_in_logbins_stc1)

plt.show()
print('{} seconds were simulated with {} s time steps.'
@@ -74,18 +75,16 @@

# %% Plotting STC

-fig = plt.figure(figsize=(12, 4))
+fig = plt.figure(figsize=(8, 10))

-plt.subplot(1, 2, 1)
+plt.subplot(2, 2, 1)
plt.plot(w, 'o', markersize=2)
plt.xlabel('Eigenvalue index')
plt.ylabel('Variance')


+plt.subplot(2, 2, 2)
eigen_legends = []

-plt.subplot(1, 2, 2)
-for i in interesting_eigen_indices:
+for i in eigen_indices:
plt.plot(v[:, i])
eigen_legends.append(str('Eigenvector '+str(i)))
plt.plot(recovered_kernel,':')
@@ -94,5 +93,12 @@
plt.title('Filters recovered by STC')
plt.xlabel('?')
plt.ylabel('?')


+plt.subplot(2,1,2)
+plt.scatter(logbins_stc1,spikecount_in_logbins_stc1)
+plt.scatter(logbins_stc2,spikecount_in_logbins_stc2)
+
+plt.title('Non-linearities recovered by STC')
plt.show()
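
The STC figure mixes two subplot grids: plt.subplot(2, 2, n) addresses quarter panels in the top row, while plt.subplot(2, 1, 2) spans the full width of the bottom half. A sketch of that layout with placeholder arrays (none of these values come from the simulation):

```python
import numpy as np
import matplotlib.pyplot as plt

w = np.sort(np.random.rand(40))                # placeholder eigenvalue spectrum
x = np.linspace(-3, 3, 60)
nlt1, nlt2 = np.maximum(x, 0) * 4.2, x**2      # placeholder recovered NLTs

fig = plt.figure(figsize=(8, 10))

plt.subplot(2, 2, 1)                           # top-left quarter
plt.plot(w, 'o', markersize=2)
plt.xlabel('Eigenvalue index')
plt.ylabel('Variance')

plt.subplot(2, 2, 2)                           # top-right quarter
plt.plot(np.sin(np.linspace(0, 6, 40)), ':')
plt.title('Filters recovered by STC')

plt.subplot(2, 1, 2)                           # bottom half, full width
plt.scatter(x, nlt1)
plt.scatter(x, nlt2)
plt.title('Non-linearities recovered by STC')

plt.show()
```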

20 changes: 14 additions & 6 deletions stc.py
@@ -37,14 +37,22 @@ def stc(spikes, stimulus, filter_length, sta_temp):
w = w[sorted_eig]
v = v[:, sorted_eig]

-interesting_eigen_indices=np.where(np.abs(w-1)>.05)[0]
-eigen_indices = [0, -1]
+eigen_indices=np.where(np.abs(w-1)>.05)[0]
+manual_eigen_indices = [0, -1]

-logbins_stc1,spikecount_in_logbins_stc1 = log_nlt_recovery(spikes,
-                                                           v[:,eigen_indices[0]], 60, k)
+filtered_recovery_stc1 = np.convolve(v[:, eigen_indices[0]], stimulus,
+                                     mode='full')[:-filter_length+1]
+
+filtered_recovery_stc2 = np.convolve(v[:, eigen_indices[1]], stimulus,
+                                     mode='full')[:-filter_length+1]
+
+logbins_stc1, spikecount_in_logbins_stc1 = log_nlt_recovery(spikes,
+                                                            filtered_recovery_stc1,
+                                                            60, k)
#quantiles_stc1,spikecount_in_bins_stc1 = q_nlt_recovery(spikes, filtered_recovery,100)
-logbins_stc2,spikecount_in_logbins_stc2 = log_nlt_recovery(spikes,
-                                                           v[:,eigen_indices[1]], 60, k)
+logbins_stc2, spikecount_in_logbins_stc2 = log_nlt_recovery(spikes,
+                                                            filtered_recovery_stc2,
+                                                            60, k)
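
The stc() helper called above is defined outside this hunk. A minimal sketch of one standard spike-triggered covariance recipe — the covariance of the spike-triggered stimulus ensemble with the STA component projected out, followed by eigendecomposition — assuming a unit-norm STA and unit-variance white-noise stimulus; the repository's implementation may differ in normalization:

```python
import numpy as np

def stc(spikes, stimulus, filter_length, sta_temp):
    # Accumulate outer products of the spike-triggered snippets,
    # with the STA direction projected out of each snippet
    covariance = np.zeros((filter_length, filter_length))
    for i in range(filter_length, len(spikes)):
        if spikes[i] == 0:
            continue
        snippet = stimulus[i - filter_length + 1:i + 1][::-1]
        snippet = snippet - np.dot(snippet, sta_temp) * sta_temp  # assumes unit-norm STA
        covariance += spikes[i] * np.outer(snippet, snippet)
    covariance /= spikes.sum() - 1
    w, v = np.linalg.eigh(covariance)   # symmetric matrix: use eigh
    sorted_eig = np.argsort(w)          # ascending variance
    return w[sorted_eig], v[:, sorted_eig]
```

For a unit-variance white-noise stimulus the uninformative eigenvalues cluster around 1, which is why eigen_indices = np.where(np.abs(w-1) > .05)[0] selects the directions whose variance the spike conditioning actually changed.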


