Skip to content

Commit

Permalink
fix bugs on delay updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Apr 18, 2022
1 parent 8dd174e commit 23c4ae9
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
3 changes: 2 additions & 1 deletion brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from brainpy.errors import ModelBuildError
from brainpy.integrators.base import Integrator
from brainpy.types import Tensor
from .utils import init_delay

__all__ = [
'DynamicalSystem',
Expand Down Expand Up @@ -468,6 +467,7 @@ def register_delay(
# delay variable
if domain == 'local':
self.local_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
self.register_implicit_nodes(self.local_delay_vars)
else:
if name not in self.global_delay_vars:
self.global_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
Expand All @@ -477,6 +477,7 @@ def register_delay(
else:
if self.global_delay_vars[name].num_delay_step - 1 < max_delay_step:
self.global_delay_vars[name].init(delay_target, max_delay_step, initial_delay_data)
self.register_implicit_nodes(self.global_delay_vars)
return delay_step

def get_delay(
Expand Down
13 changes: 6 additions & 7 deletions brainpy/dyn/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,12 @@ def __init__(
delay_target=self.pre.spike)

def update(self, _t, _dt):
# get delay
# delays
if self.delay_step is None:
pre_spike = self.pre.spike
else:
pre_spike = self.get_delay(self.pre.name + '.spike', delay_step=self.delay_step)
# update delay
self.update_delay(self.pre.name + '.spike', delay_target=self.pre.spike)
self.update_delay(self.pre.name + '.spike', delay_target=self.pre.spike)

# post values
assert self.weight_type in ['homo', 'heter']
Expand Down Expand Up @@ -321,7 +320,7 @@ def update(self, _t, _dt):
delayed_pre_spike = self.pre.spike
else:
delayed_pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step)
self.update_delay(self.pre.name + '.spike', self.pre.spike)
self.update_delay(self.pre.name + '.spike', self.pre.spike)

# post values
if isinstance(self.conn, All2All):
Expand Down Expand Up @@ -451,7 +450,7 @@ def update(self, _t, _dt):
delayed_spike = self.pre.spike
else:
delayed_spike = self.get_delay(self.pre.name + '.spike', self.delay_step)
self.update_delay(self.pre.name + '.spike', self.pre.spike)
self.update_delay(self.pre.name + '.spike', self.pre.spike)

# post values
if isinstance(self.conn, All2All):
Expand Down Expand Up @@ -622,7 +621,7 @@ def update(self, _t, _dt):
delayed_pre_spike = self.pre.spike
else:
delayed_pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step)
self.update_delay(self.pre.name + '.spike', self.pre.spike)
self.update_delay(self.pre.name + '.spike', self.pre.spike)

# post-synaptic values
self.g.value, self.h.value = self.integral(self.g, self.h, _t, _dt)
Expand Down Expand Up @@ -1130,7 +1129,7 @@ def update(self, _t, _dt):
delayed_pre_spike = self.pre.spike
else:
delayed_pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step)
self.update_delay(self.pre.name + '.spike', self.pre.spike)
self.update_delay(self.pre.name + '.spike', self.pre.spike)

# post-synaptic value
self.g.value, self.x.value = self.integral(self.g, self.x, _t, dt=_dt)
Expand Down
2 changes: 1 addition & 1 deletion brainpy/dyn/synapses/biological_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def update(self, _t, _dt):
pre_spike = self.pre.spike
else:
pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step)
self.update_delay(self.pre.name + '.spike', self.pre.spike)
self.update_delay(self.pre.name + '.spike', self.pre.spike)

# spike arrival time
self.spike_arrival_time.value = bm.where(pre_spike, _t, self.spike_arrival_time)
Expand Down

0 comments on commit 23c4ae9

Please sign in to comment.