
Commit

Merge remote-tracking branch 'origin/jax' into jax
1b15 committed Nov 1, 2024
2 parents a6f56ad + 288ca67 commit 8628d03
Showing 3 changed files with 60 additions and 4 deletions.
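
The diff threads an integration_method switch through the JAX Wilson-Cowan time integration: the existing Euler update is kept, and a Heun (predictor-corrector) step is added alongside it. For reference, Euler advances the state with a single slope evaluation, while Heun averages the slope at the start of the step with the slope at the Euler-predicted endpoint. A minimal, self-contained sketch on a generic right-hand side (illustrative names only, not code from this repository):

import jax.numpy as jnp

def euler_step(rhs, y, dt):
    # y_{n+1} = y_n + dt * f(y_n)
    return y + dt * rhs(y)

def heun_step(rhs, y, dt):
    k1 = rhs(y)                      # slope at the current state
    y_pred = y + dt * k1             # Euler predictor
    k2 = rhs(y_pred)                 # slope at the predicted state
    return y + dt * 0.5 * (k1 + k2)  # average the two slopes

# Example: exponential decay dy/dt = -y, integrated with both schemes
y_e = y_h = jnp.array(1.0)
for _ in range(10):
    y_e = euler_step(lambda y: -y, y_e, 0.1)
    y_h = heun_step(lambda y: -y, y_h, 0.1)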
59 changes: 55 additions & 4 deletions neurolib/models/jax/wc/timeIntegration.py
@@ -102,6 +102,8 @@ def timeIntegration_args(params):

# ------------------------------------------------------------------------

integration_method = params['integration_method']

return (
startind,
t,
@@ -134,6 +136,7 @@ def timeIntegration_args(params):
tau_ou,
sigma_ou,
key,
integration_method
)


@@ -170,6 +173,7 @@ def timeIntegration_elementwise(
tau_ou,
sigma_ou,
key,
integration_method
):

update_step = get_update_step(
@@ -204,6 +208,7 @@ def timeIntegration_elementwise(
tau_ou,
sigma_ou,
key,
integration_method
)

# Iterating through time steps
@@ -222,7 +227,6 @@
inh_ou,
)


def get_update_step(
startind,
t,
@@ -255,6 +259,7 @@ def get_update_step(
tau_ou,
sigma_ou,
key,
integration_method
):
key, subkey_exc = random.split(key)
noise_exc = random.normal(subkey_exc, (N, len(t)))
@@ -269,7 +274,7 @@ def S_E(x):
def S_I(x):
return 1.0 / (1.0 + jnp.exp(-a_inh * (x - mu_inh)))

def step_rhs(state):
    exc_history, inh_history, exc_ou, inh_ou, i = state

    # Vectorized calculation of delayed excitatory input
@@ -307,18 +312,64 @@ def update_step(state, _):
+ inh_ou # ou noise
)
)

exc_ou_rhs = (exc_ou_mean - exc_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc[:, i - startind]
inh_ou_rhs = (inh_ou_mean - inh_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh[:, i - startind]

return exc_rhs, inh_rhs, exc_ou_rhs, inh_ou_rhs

def euler(state):
    exc_rhs, inh_rhs, exc_ou_rhs, inh_ou_rhs = step_rhs(state)
    exc_history, inh_history, exc_ou, inh_ou, i = state

    # Euler integration
    # make sure e and i variables do not exceed 1 (can only happen with noise)
    exc_new = jnp.clip(exc_history[:, -1] + dt * exc_rhs, 0, 1)
    inh_new = jnp.clip(inh_history[:, -1] + dt * inh_rhs, 0, 1)

    # Update Ornstein-Uhlenbeck process for noise
    exc_ou = exc_ou + exc_ou_rhs  # mV/ms
    inh_ou = inh_ou + inh_ou_rhs  # mV/ms

    return exc_new, inh_new, exc_ou, inh_ou

def heun(state):
    exc_history, inh_history, exc_ou, inh_ou, i = state
    exc_k1, inh_k1, exc_ou_rhs, inh_ou_rhs = step_rhs(state)

    # Update Ornstein-Uhlenbeck process for noise
    exc_ou = exc_ou + exc_ou_rhs  # mV/ms
    inh_ou = inh_ou + inh_ou_rhs  # mV/ms

    # Predictor: plain Euler step with the k1 slope
    # make sure e and i variables do not exceed 1 (can only happen with noise)
    exc_pred = jnp.clip(exc_history[:, -1] + dt * exc_k1, 0, 1)
    inh_pred = jnp.clip(inh_history[:, -1] + dt * inh_k1, 0, 1)

    # Shift the delay history so the predicted values become the most recent entries
    exc_k1_history = jnp.concatenate((exc_history[:, 1:], jnp.expand_dims(exc_pred, axis=1)), axis=1)
    inh_k1_history = jnp.concatenate((inh_history[:, 1:], jnp.expand_dims(inh_pred, axis=1)), axis=1)

    # Corrector: evaluate the slope at the predicted state (reusing index i)
    # and average it with k1, as in the standard Heun scheme
    exc_k2, inh_k2, _, _ = step_rhs((exc_k1_history, inh_k1_history, exc_ou, inh_ou, i))
    exc_new = jnp.clip(exc_history[:, -1] + dt * 0.5 * (exc_k1 + exc_k2), 0, 1)
    inh_new = jnp.clip(inh_history[:, -1] + dt * 0.5 * (inh_k1 + inh_k2), 0, 1)

    return exc_new, inh_new, exc_ou, inh_ou

def update_step(state, _):
    exc_history, inh_history, exc_ou, inh_ou, i = state
    if integration_method == 'euler':
        integration_f = euler
    elif integration_method == 'heun':
        integration_f = heun
    else:
        raise ValueError(f'Integration method {integration_method} not implemented.')
    exc_new, inh_new, exc_ou, inh_ou = integration_f(state)

return (
(
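
The carry-and-ignore-input signature update_step(state, _) is the shape expected by jax.lax.scan, which presumably drives the "# Iterating through time steps" loop in the collapsed part of the diff. A toy sketch of that pattern, with an illustrative scalar state rather than the model's actual carry:

import jax
import jax.numpy as jnp

def update_step(state, _):
    x, i = state
    x_new = x + 0.1 * (-x)        # one Euler step of dx/dt = -x with dt = 0.1
    return (x_new, i + 1), x_new  # (new carry, per-step output)

init = (jnp.array(1.0), 0)
(_, _), trajectory = jax.lax.scan(update_step, init, xs=None, length=100)
# trajectory holds the state at each of the 100 steps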
2 changes: 2 additions & 0 deletions neurolib/models/wc/loadDefaultParams.py
@@ -81,4 +81,6 @@ def loadDefaultParams(Cmat=None, Dmat=None, seed=None):
params.exc_ou = np.zeros((params.N,))
params.inh_ou = np.zeros((params.N,))

params.integration_method = 'euler'

return params
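
With the default registered in loadDefaultParams, the scheme can be chosen per run through the params dictionary, as the tests below do for 'euler'. A hypothetical usage sketch (the WCModel_jax import path is assumed, mirroring tests/test_jax.py):

from neurolib.models.jax.wc import WCModel_jax  # import path assumed

model = WCModel_jax(seed=0)
model.params["duration"] = 1.0 * 1000
model.params["sigma_ou"] = 0.0
model.params["integration_method"] = "heun"  # 'euler' is the default set above
model.run()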
3 changes: 3 additions & 0 deletions tests/test_jax.py
@@ -26,6 +26,7 @@ def test_single_node_deterministic(self):
model_jax = WCModel_jax(seed=0)
model_jax.params["duration"] = 1.0 * 1000
model_jax.params["sigma_ou"] = 0.0
model_jax.params['integration_method'] = 'euler'

model_jax.run()

@@ -48,6 +49,7 @@ def test_single_node_dist(self):
model_jax = WCModel_jax()
model_jax.params["duration"] = 5.0 * 1000
model_jax.params["sigma_ou"] = 0.01
model_jax.params['integration_method'] = 'euler'

model_jax.run()

@@ -86,6 +88,7 @@ def test_network(self):
model.params["duration"] = 10 * 1000
model.params["sigma_ou"] = 0.0
model.params["K_gl"] = 0.6
model_jax.params['integration_method'] = 'euler'

# local node input parameter
model.params["exc_ext"] = 0.72
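
A natural follow-up test, not part of this commit, would compare the two schemes on the deterministic single-node case (sigma_ou = 0), where they should agree up to discretization error. A hedged sketch, assuming the model exposes its excitatory trace as model.exc as elsewhere in neurolib:

import numpy as np

def run_with(method):
    m = WCModel_jax(seed=0)
    m.params["duration"] = 1.0 * 1000
    m.params["sigma_ou"] = 0.0
    m.params["integration_method"] = method
    m.run()
    return np.array(m.exc)

exc_euler = run_with("euler")
exc_heun = run_with("heun")
assert np.allclose(exc_euler, exc_heun, atol=1e-2)  # tolerance is illustrative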
