Skip to content

Commit

Permalink
Merge branch 'master' into updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Apr 22, 2022
2 parents 1df9c76 + 9bb11a8 commit 1c2df42
Show file tree
Hide file tree
Showing 6 changed files with 492 additions and 364 deletions.
8 changes: 4 additions & 4 deletions brainpy/connect/custom_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, conn_mat):
self.pre_num, self.post_num = conn_mat.shape
self.pre_size, self.post_size = (self.pre_num,), (self.post_num,)

self.conn_mat = np.asarray(conn_mat, dtype=MAT_DTYPE)
self.conn_mat = np.asarray(conn_mat).astype(MAT_DTYPE)

def __call__(self, pre_size, post_size):
assert self.pre_num == tools.size2num(pre_size)
Expand All @@ -47,8 +47,8 @@ def __init__(self, i, j):
assert i.size == j.size

# initialize the class via "pre_ids" and "post_ids"
self.pre_ids = np.asarray(i, dtype=IDX_DTYPE)
self.post_ids = np.asarray(j, dtype=IDX_DTYPE)
self.pre_ids = np.asarray(i).astype(IDX_DTYPE)
self.post_ids = np.asarray(j).astype(IDX_DTYPE)

def __call__(self, pre_size, post_size):
super(IJConn, self).__call__(pre_size, post_size)
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(self, csr_mat):
f'Please run "pip install scipy" to install scipy.')

assert isinstance(csr_mat, csr_matrix)
csr_mat.data = np.asarray(csr_mat.data, dtype=MAT_DTYPE)
csr_mat.data = np.asarray(csr_mat.data).astype(MAT_DTYPE)
self.csr_mat = csr_mat
self.pre_num, self.post_num = csr_mat.shape

Expand Down
17 changes: 12 additions & 5 deletions brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,16 +483,19 @@ def register_delay(
def get_delay(
self,
name: str,
delay_step: Union[int, bm.JaxArray, bm.ndarray]
delay_step: Union[int, bm.JaxArray, bm.ndarray],
indices=None,
):
"""Get delay data according to the delay times.
"""Get delay data according to the provided delay steps.
Parameters
----------
name: str
The delay variable name.
delay_step: int, JaxArray, ndarray
The delay length.
indices: optional, JaxArray, ndarray
The indices of the delay.
Returns
-------
Expand All @@ -501,14 +504,18 @@ def get_delay(
"""
if name in self.global_delay_vars:
if isinstance(delay_step, int):
return self.global_delay_vars[name](delay_step)
return self.global_delay_vars[name](delay_step, indices)
else:
return self.global_delay_vars[name](delay_step, jnp.arange(delay_step.size))
if indices is None:
indices = jnp.arange(delay_step.size)
return self.global_delay_vars[name](delay_step, indices)
elif name in self.local_delay_vars:
if isinstance(delay_step, int):
return self.local_delay_vars[name](delay_step)
else:
return self.local_delay_vars[name](delay_step, jnp.arange(delay_step.size))
if indices is None:
indices = jnp.arange(delay_step.size)
return self.local_delay_vars[name](delay_step, indices)
else:
raise ValueError(f'{name} is not defined in delay variables.')

Expand Down
Loading

0 comments on commit 1c2df42

Please sign in to comment.