Skip to content

Commit

Permalink
fix stupid ** vs * typo
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Feb 16, 2024
1 parent 601d62e commit ecb48fd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
7 changes: 4 additions & 3 deletions vbjax/neural_mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ def dcm_dfun(x, u, p: DCMTheta):

DodyTheta = collections.namedtuple(
typename='DodyTheta',
field_names='a, b, c, ga, gg, Eta, Delta, Iext, Ea, Eg, Sja, Sjg,'
'tauSa, tauSg, alpha, beta, ud, k, Vmax, Km, Bd, Ad, tau_Dp')
field_names='a, b, c, ga, gg, Eta, Delta, Iext, Ea, Eg, Sja, Sjg, tauSa, tauSg, alpha, beta, ud, k, Vmax, Km, Bd, Ad, tau_Dp')

DodyState = collections.namedtuple(
typename='DodyState',
Expand All @@ -172,10 +171,12 @@ def dody_dfun(y: DodyState, cy: DodyCouplings, p: DodyTheta):
# c_dopa = Cd @ r
c_inh, c_exc, c_dopa = cy
a, b, c, ga, gg, Eta, Delta, Iext, Ea, Eg, Sja, Sjg, tauSa, tauSg, alpha, beta, ud, k, Vmax, Km, Bd, Ad, tau_Dp = p

dr = 2. * a * r * V + b * r - ga * Sa * r - gg * Sg * r + (a * Delta) / np.pi
dV = a * V**2 + b * V + c + Eta - (np.pi*2 * r**2) / a + (Ad * Dp + Bd) * ga * Sa * (Ea - V) + gg * Sg * (Eg - V) + Iext - u
dV = a * V**2 + b * V + c + Eta - (np.pi**2 * r**2) / a + (Ad * Dp + Bd) * ga * Sa * (Ea - V) + gg * Sg * (Eg - V) + Iext - u
du = alpha * (beta * V - u) + ud * r
dSa = -Sa / tauSa + Sja * c_exc
dSg = -Sg / tauSg + Sjg * c_inh
dDp = (k * c_dopa - Vmax * Dp / (Km + Dp)) / tau_Dp

return DodyState(dr, dV, du, dSa, dSg, dDp)
50 changes: 35 additions & 15 deletions vbjax/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import vbjax as vb


def test_dody():
def true_dopa():

a=0.04
b=5.
Expand Down Expand Up @@ -71,11 +71,6 @@ def aQIFdopa(y,t,params,coupling_inhibitor,coupling_excitator,coupling_dopamine)

def network(y, t, ckk, params):
r = y[0*n_nodes : 1*n_nodes]
V = y[1*n_nodes : 2*n_nodes]
u = y[2*n_nodes : 3*n_nodes]
Sa = y[3*n_nodes : 4*n_nodes]
Sg = y[4*n_nodes : 5*n_nodes]
Dp = y[5*n_nodes : 6*n_nodes]

aff_inhibitor = conn_inhibitor @ r * ckk
aff_excitator = conn_excitator @ r * ckk
Expand All @@ -88,24 +83,28 @@ def heun_SDE(network,y0,t0,t_max,dt,params,ckk,sigma):
num_steps = int((t_max - t0) / dt)
y_all = np.empty((num_steps, len(y0)))
t_all = np.empty((num_steps, ))
stochastic_matrix = sigma * np.random.normal(0, 1, (len(y0),num_steps))*np.sqrt(dt)
stochastic_matrix = np.random.normal(0, 1, (len(y0),num_steps))
t=t0; i=0
t_all[i] = t0
y_all[i, :] = y0
y=y0
dws = []
for step in range(num_steps):
dw = stochastic_matrix[:,step]
dws.append(dw)
ye = y + dt * network(y, t, ckk,params) + dw
y = y + 0.5 * dt * (network(y, t, ckk,params) + network(ye, t + dt, ckk,params)) + dw
dw = stochastic_matrix[:,step]*sigma * np.sqrt(dt)
dy1 = network(y, t, ckk,params)
ye = y + dt * dy1 + dw
y = y + 0.5 * dt * (dy1 + network(ye, t + dt, ckk,params)) + dw
t=t+dt
t_all[i]=t
y_all[i,:]=y
i+=1
return y_all, t_all, np.array(dws)
return y_all, t_all, stochastic_matrix.T

y1, t1, dw = heun_SDE(network,y0,t0,tf,dt,params,ckk,sigma)
return y1, t1, dw, ckk, params, conn_inhibitor, conn_excitator, conn_dopamine, n_nodes, r0, V0, u0, Sa0, Sg0, Dp0, network, dt, sigma

def test_dody():

y1, t1, dw, ckk, params, conn_inhibitor, conn_excitator, conn_dopamine, n_nodes, r0, V0, u0, Sa0, Sg0, Dp0, network, dt, sigma = true_dopa()

# now in vbjax
def net(y, p):
Expand All @@ -123,7 +122,28 @@ def net(y, p):
j_dw = vb.DodyState(*jp.array(dw).reshape((-1, 6, n_nodes)).transpose(1,0,2))
j_y2: vb.DodyState = loop(j_y0, j_dw, (j_Ci, j_Ce, j_Cd, ckk, j_params))

# compare derivatives
for i in range(t1.size):
dy1 = network(y1[i], t1[i], ckk, params).reshape((6, -1))
dy2 = net(y1[i].reshape((6,-1)), (j_Ci, j_Ce, j_Cd, ckk, j_params))
for j in range(6):
# print(i, j)
np.testing.assert_allclose(dy1[j], dy2[j], atol=1e-2, rtol=0.1)

# compare trajectories
y1_ = y1.reshape((-1, 6, n_nodes))
for i in range(6):
np.testing.assert_allclose(y1_[:,i], j_y2[i], atol=1e-2, rtol=0.1)
if False:
# do plots
import matplotlib.pyplot as pl
for i in range(6):
pl.subplot(3, 2, i + 1)
pl.plot(t1, y1_[:,i], 'k', alpha=0.2)
pl.plot(t1, j_y2[i], 'r', alpha=0.2)
pl.grid(1)
np.testing.assert_allclose(y1_[:,i], j_y2[i], atol=1e-2, rtol=0.1)
pl.savefig('dody.png', dpi=300)
else:
# don't bother plots just assert all close each var
for i in range(6):
np.testing.assert_allclose(y1_[:,i], j_y2[i], atol=1e-2, rtol=0.1)

0 comments on commit ecb48fd

Please sign in to comment.