Skip to content

Commit

Permalink
minor improvements and some commenting
Browse files Browse the repository at this point in the history
  • Loading branch information
epnev committed Jul 15, 2016
1 parent 070194a commit 41d92e6
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 134 deletions.
98 changes: 48 additions & 50 deletions cont_ca_sampler.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
% add the capability of dealing with missing numbers

% Continuous time sampler
% Y data (normalized in [0,1])
% Y data
% P intialization parameters (discrete time constant P.g required)
% params parameters structure
% params.g discrete time constant(s) (estimated if not provided)
Expand All @@ -27,50 +27,55 @@
% params.c1_lb lower bound for initial value (default 0)

% output struct SAMPLES
% spikes T x Nsamples matrix with spikes samples
% bp Nsamples x 1 vector with samples for spiking prior probability
% ss Nsamples x 1 cells with spike times for each sample
% ns Nsamples x 1 vector with number of spikes
% Am Nsamples x 1 vector with samples for spike amplitude
% ld Nsamples x 1 vector with samples for firing rate

% If marginalized sampler is used
% If marginalized sampler is used (params.marg = 1)
% Cb posterior mean and sd for baseline
% Cin posterior mean and sd for initial condition
% else
% Cb Nsamples x 1 vector with samples for baseline
% Cin Nsamples x 1 vector with samples for initial concentration
% sn Nsamples x 1 vector with samples for noise variance

% If gamma is updated
% If gamma is updated (params.upd_gam = 1)
% g Nsamples x p vector with the gamma updates

% Author: Eftychios A. Pnevmatikakis and Josh Merel

Y = Y(:);
T = length(Y);
isanY = ~isnan(Y);
isanY = ~isnan(Y); % Deal with possible missing entries
E = speye(T);
E = E(isanY,:);

% define default parameters
defparams.g = [];
defparams.sn = [];
defparams.b = [];
defparams.c1 = [];
defparams.Nsamples = 400;
defparams.B = 200;
defparams.marg = 0;
defparams.upd_gam = 1;
defparams.gam_step = 1;
defparams.A_lb = 0.1*range(Y);
defparams.b_lb = quantile(Y,0.01);
defparams.c1_lb = 0;

defparams.std_move = 3;
defparams.add_move = ceil(T/100);
defparams.init = [];
defparams.f = 1;
defparams.p = 1;
defparams.defg = [0.6,0.95];
defparams.TauStd = [.1,1];
defparams.g = []; % initializer for time constants
defparams.sn = []; % initializer for noise std
defparams.b = []; % initializer for baseline concentration
defparams.c1 = []; % initializer for initial concentration
defparams.c = []; % initializer for calcium concentration
defparams.sp = []; % initializer for spiking signal
defparams.Nsamples = 400; % number of samples after burn in period
defparams.B = 200; % length of burn in period
defparams.marg = 0; % flag to marginalize out baseline and initial concentration
defparams.upd_gam = 1; % flag for updating time constants
defparams.gam_step = 1; % flag for how often to update time constants
defparams.A_lb = 0.1*range(Y); % lower bound for spike amplitude
defparams.b_lb = quantile(Y,0.01); % lower bound for baseline
defparams.c1_lb = 0; % lower bound for initial concentration

defparams.std_move = 3; % standard deviation of spike move kernel
defparams.add_move = ceil(T/100); % number of add moves
defparams.init = []; % sampler initializer
defparams.f = 1; % imaging rate (irrelevant)
defparams.p = 1; % order of AR process (use p = 1 or p = 2)
defparams.defg = [0.6,0.95]; % default time constant roots
defparams.TauStd = [.1,1]; % Standard deviation for time constant proposal
defparams.prec = 1e-2; % Precision parameter when adding new spikes
defparams.con_lam = true; % Flag for constant firing across time
defparams.print_flag = 0;

if nargin < 2
Expand All @@ -80,6 +85,8 @@
if ~isfield(params,'sn'); params.sn = defparams.sn; end
if ~isfield(params,'b'); params.b = defparams.b; end
if ~isfield(params,'c1'); params.c1 = defparams.c1; end
if ~isfield(params,'c'); params.c = defparams.c; end
if ~isfield(params,'sp'); params.sp = defparams.sp; end
if ~isfield(params,'Nsamples'); params.Nsamples = defparams.Nsamples; end
if ~isfield(params,'B'); params.B = defparams.B; end
if ~isfield(params,'marg'); params.marg = defparams.marg; end
Expand All @@ -95,7 +102,9 @@
if ~isfield(params,'A_lb'); params.A_lb = defparams.A_lb; end
if ~isfield(params,'b_lb'); params.b_lb = defparams.b_lb; end
if ~isfield(params,'c1_lb'); params.c1_lb = defparams.c1_lb; end
if ~isfield(params,'print_flag'); params.print_flag = defparams.print_flag; end
if ~isfield(params,'prec'); params.prec = defparams.prec; end
if ~isfield(params,'con_lam'); params.con_lam = defparams.con_lam; end
if ~isfield(params,'print_flag'); params.print_flag = defparams.print_flag; end
end

Dt = 1; % length of time bin
Expand All @@ -113,14 +122,13 @@
end

if isempty(params.init)
fprintf('Initializing using noise constrained FOOPSI... ');
params.init = get_initial_sample(Y,params);
fprintf('done. \n');
params.init = get_initial_sample(Y,params);
end
SAM = params.init;
g = SAM.g(:)';

g = SAM.g(:)'; % check initial time constants, if not reasonable set to default values
if g == 0
gr = [0.9,0.1];
gr = params.defg;
pl = poly(gr);
g = -pl(2:end);
p = 2;
Expand All @@ -143,11 +151,8 @@
end
G2 = spdiags(ones(T,1)*[-max(gr),1],[-1:0],T,T);


sg = SAM.sg;


SAM = params.init;
spiketimes_ = SAM.spiketimes_;
lam_ = SAM.lam_;
A_ = SAM.A_*diff(gr);
Expand All @@ -159,14 +164,9 @@
s_1(ceil(spiketimes_/Dt)) = exp((spiketimes_ - Dt*ceil(spiketimes_/Dt))/tau(1));
s_2(ceil(spiketimes_/Dt)) = exp((spiketimes_ - Dt*ceil(spiketimes_/Dt))/tau(2));

if ~isfield(params,'prec') % FN % (from Eftychios: prec specifies to what extent you want to discard the long slowly decaying tales of the ca response. Try setting it e.g., to 5e-2 instead of 1e-2 to speed things up.)
prec = 1e-2; % precision
else
prec = params.prec; %5e-2; % FN
end

prec = params.prec;

ef_d = exp(-(0:T)/tau(2));
ef_d = exp(-(0:T)/tau(2)); % construct transient exponentials
if p == 1
h_max = 1; % max value of transient
ef_h = [0,0];
Expand Down Expand Up @@ -206,7 +206,7 @@
SG = zeros(N,1);
end

Sp = .1*range(Y)*eye(3); % prior covariance
Sp = .1*range(Y)*eye(3); % prior covariance for [A,Cb,Cin]
Ld = inv(Sp);
lb = [params.A_lb/h_max*diff(gr),params.b_lb,params.c1_lb]'; % lower bound for [A,Cb,Cin]

Expand All @@ -230,13 +230,13 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%


for i = 1:N
for i = 1:N
if gam_flag
Gam(i,:) = tau;
end
sg_ = sg;
rate = @(t) lambda_rate(t,lam_);
[spiketimes, ~] = get_next_spikes(spiketimes_(:)',A_*Gs',Ym',ef,tau,sg_^2, rate, std_move, add_move, Dt, A_);
[spiketimes, ~] = get_next_spikes(spiketimes_(:)',A_*Gs',Ym',ef,tau,sg_^2, rate, std_move, add_move, Dt, A_, params.con_lam);
spiketimes(spiketimes<0) = -spiketimes(spiketimes<0);
spiketimes(spiketimes>T*Dt) = 2*T*Dt - spiketimes(spiketimes>T*Dt);
spiketimes_ = spiketimes;
Expand Down Expand Up @@ -293,8 +293,8 @@
end
Am(i) = A_;
if i > B
mub = mub + mu_post(2+(0:p));
Sigb = Sigb + L(2+(0:p),2+(0:p));
mub = mub + mu_post(1+(1:p));
Sigb = Sigb + L(1+(1:p),1+(1:p));
end
end
if gam_flag
Expand All @@ -316,7 +316,6 @@
logC_ = -norm(E*(Y(:)-A_*Gs-b_-C_in*ge))^2;
%accept or reject
prior_ratio = 1;
% prior_ratio = gampdf(tau_(2),12,1)/gampdf(tau(2),12,1);
ratio = exp((logC_-logC)/(2*sg^2))*prior_ratio;
if rand < ratio %accept
tau = tau_;
Expand Down Expand Up @@ -351,7 +350,6 @@

%accept or reject
prior_ratio = 1;
% prior_ratio = gampdf(tau_(2),12,1)/gampdf(tau(2),12,1);
ratio = exp((1./(2*sg^2)).*(logC_-logC))*prior_ratio;
if rand<ratio %accept
tau = tau_;
Expand Down Expand Up @@ -397,7 +395,7 @@

if marg_flag
SAMPLES.Cb = [mub(1),sqrt(Sigb(1,1))];
SAMPLES.Cin = [mub(1+(1:p)),sqrt(diag(Sigb(1+(1:p),1+(1:p))))];
SAMPLES.Cin = [mub(2),sqrt(Sigb(2,2))]; %[mub(1+(1:p)),sqrt(diag(Sigb(1+(1:p),1+(1:p))))];
else
SAMPLES.Cb = Cb(B+1:N);
SAMPLES.Cin = Cin(B+1:N,:);
Expand Down
11 changes: 3 additions & 8 deletions plot_continuous_samples.m
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,9 @@ function plot_continuous_samples(SAMPLES,Y)
subplot(4,4,15); plot(xx,normpdf(xx,SAMPLES.Cb(1),SAMPLES.Cb(2)));
set(gca,'XLim',[xx(1),xx(end)])
title('Marg. post. of baseline','fontweight','bold','fontsize',14)
if p == 1
xx = SAMPLES.Cin(1) + linspace(-4*SAMPLES.Cin(2),4*SAMPLES.Cin(2));
subplot(4,4,16); plot(xx,normpdf(xx,SAMPLES.Cin(1),SAMPLES.Cin(2)));
else
xx = linspace(min(SAMPLES.Cin(1),SAMPLES.Cin(2))-4*min(SAMPLES.Cin(3),SAMPLES.Cin(4)),max(SAMPLES.Cin(1),SAMPLES.Cin(2)) + 4*max(SAMPLES.Cin(3),SAMPLES.Cin(4)));
subplot(4,4,16); plot(xx,normpdf(xx,SAMPLES.Cin(1),SAMPLES.Cin(3))); hold all;
plot(xx,normpdf(xx,SAMPLES.Cin(2),SAMPLES.Cin(4)));
end

xx = SAMPLES.Cin(1) + linspace(-4*SAMPLES.Cin(2),4*SAMPLES.Cin(2));
subplot(4,4,16); plot(xx,normpdf(xx,SAMPLES.Cin(1),SAMPLES.Cin(2)));
set(gca,'XLim',[xx(1),xx(end)])
title('Marg. post. of initial con','fontweight','bold','fontsize',14)
else
Expand Down
64 changes: 0 additions & 64 deletions plot_continuous_samples_l.m

This file was deleted.

14 changes: 10 additions & 4 deletions sampling_demo_ar2.m
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,15 @@
title('Foopsi Spikes','FontWeight','bold','Fontsize',14); xlabel('Timestep','FontWeight','bold','Fontsize',16);
legend('Foopsi Spikes','Ground Truth');
drawnow;
%% MCMC
%% MCMC

params.p = 2;
params.g = g2;

SAMPLES = cont_ca_sampler(y,params); %% MCMC
plot_continuous_samples(SAMPLES,y(:));
params.sp = spikes_foopsi; % pass results of foopsi for initialization (if not, they are computed)
params.c = ca_foopsi;
params.b = cb;
params.c1 = c1;
params.sn = sg;
params.marg = 0;
SAMPLES2 = cont_ca_sampler(y,params); %% MCMC
plot_continuous_samples(SAMPLES2,y(:));
13 changes: 12 additions & 1 deletion utilities/get_initial_sample.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@


if isfield(params,'p'); options.p = params.p; else options.p = 1; end
[c,b,c1,g,sn,sp] = constrained_foopsi(Y,params.b,params.c1,params.g,params.sn,options);
if isempty(params.c) || isempty(params.b) || isempty(params.c1) || isempty(params.g) || isempty(params.sn) || isempty(params.sp)
fprintf('Initializing using noise constrained FOOPSI... ');
[c,b,c1,g,sn,sp] = constrained_foopsi(Y,params.b,params.c1,params.g,params.sn,options);
fprintf('done. \n');
else
c = params.c;
b = params.b;
c1 = params.c1;
g = params.g;
sn = params.sn;
sp = params.sp;
end

Dt = 1;
T = length(Y);
Expand Down
11 changes: 4 additions & 7 deletions utilities/get_next_spikes.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [samples, ci] = get_next_spikes(curr_spikes,curr_calcium,calciumSignal,ef,tau,calciumNoiseVar, lam, proposalVar, add_move, Dt, A)
function [samples, ci] = get_next_spikes(curr_spikes,curr_calcium,calciumSignal,ef,tau,calciumNoiseVar, lam, proposalVar, add_move, Dt, A, con_lam)

%addMoves, dropMoves, and timeMoves give acceptance probabilities for each subclass of move
%the samples will be a cell array of lists of spike times - the spike times won't be sorted but this shouldn't be a problem.
Expand Down Expand Up @@ -35,17 +35,13 @@
% 4) Gibbs for time shifts with likelihood proposal add/drop

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%





%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% loop over sweeps to generate samples
addMoves = [0 0]; %first elem is number successful, second is number total
dropMoves = [0 0];
timeMoves = [0 0];
time_move = 0;
time_add = 0;
for i = 1:nsweeps

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Expand All @@ -68,7 +64,8 @@
[si_, ci_, logC_] = replaceSpike(si,ci,logC,ef,tau,calciumSignal,tmpi,ni,tmpi_,Dt,A);

%accept or reject
ratio = exp((logC_-logC)/(2*calciumNoiseVar)*lam(tmpi)/lam(tmpi_));
ratio = exp((logC_-logC)/(2*calciumNoiseVar));
if ~con_lam; ratio = ratio*lam(tmpi)/lam(tmpi_); end
if ratio>1 %accept
si = si_;
ci = ci_;
Expand Down

0 comments on commit 41d92e6

Please sign in to comment.