Skip to content

Commit

Permalink
Add NLT recovery for STC and restructure plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
ycanerol committed May 19, 2017
1 parent 410c5ae commit 430835a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 28 deletions.
13 changes: 9 additions & 4 deletions LNP_model
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,13 @@ def log_nlt_recovery(spikes, filtered_recovery, bin_nr, k):
logbins = -logbins[::-1]+logbins
logbindices = np.digitize(filtered_recovery, logbins)
spikecount_in_logbins = np.array([])
for i in range(log_bin_nr):
for i in range(bin_nr):
spikecount_in_logbins = np.append(spikecount_in_logbins,
(np.average(spikes[np.where
(logbindices == i)]))/dt)
return logbins, spikecount_in_logbins

logbins,spikecount_in_logbins = log_nlt_recovery(spikes,
filtered_recovery, 60, k)

# %% Using mquantiles
def q_nlt_recovery(spikes,filtered_recovery, bin_nr):
quantiles = mquantiles(filtered_recovery,
Expand All @@ -126,7 +125,13 @@ def q_nlt_recovery(spikes,filtered_recovery, bin_nr):
(bindices == i)])/dt))
return quantiles,spikecount_in_bins

quantiles,spikecount_in_bins = q_nlt_recovery(spikes, filtered_recovery,100)
logbins_sta,spikecount_in_logbins_sta = log_nlt_recovery(spikes,
filtered_recovery, 60, k)

quantiles_sta,spikecount_in_bins_sta = q_nlt_recovery(spikes,
filtered_recovery,100)

runfile('/Users/ycan/Documents/official/gottingen/lab rotations/LR3 Gollisch/scripts/stc.py', wdir='/Users/ycan/Documents/python')

runfile('/Users/ycan/Documents/official/gottingen/lab rotations/LR3 Gollisch/scripts/plotLNP.py', wdir='/Users/ycan/Documents/python')

Expand Down
29 changes: 27 additions & 2 deletions plotLNP.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@
plt.plot(k,cweight*nlt(k,nlt_index1)+(1-cweight)*nlt(k,nlt_index2),alpha=.6)

plt.subplot(rows,columns,2)
plt.scatter(logbins,spikecount_in_logbins,s=6,alpha=.6)
plt.scatter(logbins_sta,spikecount_in_logbins_sta,s=6,alpha=.6)

plt.subplot(rows,columns,2)
plt.scatter(quantiles,spikecount_in_bins,s=6,alpha=.6)
plt.scatter(quantiles_sta,spikecount_in_bins_sta,s=6,alpha=.6)
plt.legend(['Non-linear transformation {}'.format(nlt_index1),
'Non-linear transformation {}'.format(nlt_index2),
'{}*NLT{}+{}*NLT{}'.format(cweight,nlt_index1,
Expand All @@ -71,3 +71,28 @@
.format(np.round(np.max(t)),dt))
print('{} spikes generated.'.format(int(sum(spikes)))
)

# %% Plotting STC

fig = plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(w, 'o', markersize=2)
plt.xlabel('Eigenvalue index')
plt.ylabel('Variance')


eigen_legends = []

plt.subplot(1, 2, 2)
for i in interesting_eigen_indices:
plt.plot(v[:, i])
eigen_legends.append(str('Eigenvector '+str(i)))
plt.plot(recovered_kernel,':')
eigen_legends.append('STA')
plt.legend(eigen_legends, fontsize='x-small')
plt.title('Filters recovered by STC')
plt.xlabel('?')
plt.ylabel('?')
plt.show()

28 changes: 6 additions & 22 deletions stc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,14 @@ def stc(spikes, stimulus, filter_length, sta_temp):
w = w[sorted_eig]
v = v[:, sorted_eig]

fig = plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(w, 'o', markersize=2)
plt.xlabel('Eigenvalue index')
plt.ylabel('Variance')

interesting_eigen_indices=np.where(np.abs(w-1)>.05)[0]
eigen_indices = [0, -1]
eigen_legends = []

plt.subplot(1, 2, 2)
for i in interesting_eigen_indices:
plt.plot(v[:, i])
eigen_legends.append(str('Eigenvector '+str(i)))
plt.plot(recovered_kernel,':')
eigen_legends.append('STA')
plt.legend(eigen_legends, fontsize='x-small')
plt.title('Filters recovered by STC')
plt.xlabel('?')
plt.ylabel('?')
logbins_stc1,spikecount_in_logbins_stc1 = log_nlt_recovery(spikes,
v[:,eigen_indices[0]], 60, k)
#quantiles_stc1,spikecount_in_bins_stc1 = q_nlt_recovery(spikes, filtered_recovery,100)
logbins_stc2,spikecount_in_logbins_stc2 = log_nlt_recovery(spikes,
v[:,eigen_indices[1]], 60, k)


#plt.plot(v[:, -1])
#plt.plot(v[:, -2])
#plt.legend(['1', '2', '-1', '-2'], fontsize='x-small')
plt.show()

0 comments on commit 430835a

Please sign in to comment.