diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 705108c9a..c0344c962 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -64,7 +64,6 @@ dyn, # dynamics module # delay, # delay module ) -from brainpy._src.delay import (DataDelay, TargetDelay) from brainpy.synapses import ( synouts, # synaptic output diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index b21f2e6b4..b4cb5b21a 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -863,7 +863,7 @@ class PowerLaw(TwoEndConnector): Phys. Rev. E, 65, 026107, 2002. """ - def __init__(self, m, p, directed=False, seed=None, **kwargs): + def __init__(self, m: int, p: float, directed=False, seed=None, **kwargs): super(PowerLaw, self).__init__(**kwargs) self.m = m self.p = p diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index 0a52974b4..2f2681b79 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -22,7 +22,7 @@ __all__ = [ 'Delay', - 'TargetDelay', + 'VariableDelay', 'DataDelay', ] @@ -440,7 +440,7 @@ def _check_target_sharding(sharding, ndim, mode: bm.Mode): return sharding -class TargetDelay(Delay): +class VariableDelay(Delay): """Delay variable which has a fixed delay length. The data in this delay variable is arranged as:: @@ -580,7 +580,7 @@ def at(self, entry: str, *indices) -> bm.Array: if entry not in self._registered_entries: raise KeyError(f'Does not find delay entry "{entry}".') delay_step = self._registered_entries[entry] - if delay_step is None: + if delay_step is None or delay_step == 0.: return self.target.value else: assert self.data is not None @@ -691,322 +691,16 @@ def _init_data(self, length: int, batch_size: int = None): self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype) -class DataDelay(Delay): - """Delay variable which has a fixed delay length. - - The data in this delay variable is arranged as:: - - delay = 0 [ data - delay = 1 data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - - Args: - target: Variable. The delay target. - sharding: sequence of str. The name for each axis. - time: int, float. The delay time. - init: Any. The delay data. It can be a Python number, like float, int, boolean values. - It can also be arrays. Or a callable function or instance of ``Connector``. - Note that ``initial_delay_data`` should be arranged as the following way:: - - delay = 1 [ data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - entries: optional, dict. The delay access entries. - name: str. The delay name. - method: str. The method used for updating delay. Default None. - mode: Mode. The computing mode. Default None. 
- - """ - +class DataDelay(VariableDelay): + not_desc_params = ('time', 'entries') def __init__( self, - # delay info - data_size: Union[int, Sequence[int]], - data_type: type, - sharding: Optional[Sequence[str]] = None, - - # delay time - time: Optional[Union[int, float]] = None, - - # delay init - init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, - - # delay access entry - entries: Optional[Dict] = None, - - # delay method - method: Optional[str] = None, - - # others - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(time=time, init=init, method=method, name=name, mode=mode) - - data_size = tools.to_size(data_size) - self.data_size = data_size - self.data_type = data_type - - # sharding - self._sharding = sharding - if sharding is not None: - if len(sharding) == len(data_size): - sharding = list(sharding) - elif len(sharding) + 1 == len(data_size) and self.mode.is_child_of(bm.BatchingMode): - sharding = list(sharding) - sharding.insert(0, bm.sharding.BATCH_AXIS) - else: - raise ValueError('sharding axis names do not match the target dimension. ') - self._target_sharding = tuple(sharding) - if sharding is not None: - sharding = list(sharding) - sharding.insert(0, bm.sharding.TIME_AXIS) - self._data_sharding = tuple(sharding) - - # target - target = variable_(partial(bm.zeros, dtype=data_type), - data_size, - self.mode, - axis_names=sharding, - batch_axis_name=bm.sharding.BATCH_AXIS) - self.target = bm.sharding.partition(bm.asarray(target), self._target_sharding) - - # delay data - self._init = init - if self.max_length > 0: - self._init_data(self.max_length) - else: - self.data = None - - # other info - if entries is not None: - for entry, value in entries.items(): - self.register_entry(entry, value) - - def register_entry( - self, - entry: str, - delay_time: Optional[Union[int, float]], - ) -> 'Delay': - """Register an entry to access the data. - - Args: - entry: str. The entry to access the delay data. - delay_time: The delay time of the entry (can be a float). - - Returns: - Return the self. - """ - if entry in self._registered_entries: - raise KeyError(f'Entry {entry} has been registered.') - - if isinstance(delay_time, (np.ndarray, jax.Array)): - assert delay_time.size == 1 and delay_time.ndim == 0 - delay_time = delay_time.item() - - if delay_time is None: - delay_step = None - delay_time = 0. - else: - assert isinstance(delay_time, (int, float)) - delay_step = math.ceil(delay_time / bm.get_dt()) - - # delay variable - if delay_step is not None: - if self.max_length < delay_step: - self._init_data(delay_step) - self.max_length = delay_step - self.max_time = delay_time - self._registered_entries[entry] = delay_step - return self - - def at(self, entry: str, *indices) -> bm.Array: - """Get the data at the given entry. - - Args: - entry: str. The entry to access the data. - *indices: The slicing indices. - - Returns: - The data. 
- """ - assert isinstance(entry, str), 'entry should be a string for describing the ' - if entry not in self._registered_entries: - raise KeyError(f'Does not find delay entry "{entry}".') - delay_step = self._registered_entries[entry] - if delay_step is None: - return self.target.value - else: - assert self.data is not None - if delay_step == 0: - return self.target.value - else: - return self.retrieve(delay_step, *indices) - - @property - def delay_target_shape(self): - """The data shape of the delay target.""" - return self.target.shape - - def __repr__(self): - name = self.__class__.__name__ - return f'{name}(step={self.max_length}, shape={self.delay_target_shape}, method={self.method})' - - def _check_delay(self, delay_len): - raise ValueError(f'The request delay length should be less than the ' - f'maximum delay {self.max_length}. ' - f'But we got {delay_len}') - - def retrieve(self, delay_step, *indices): - """Retrieve the delay data according to the delay length. - - Parameters - ---------- - delay_step: int, ArrayType - The delay length used to retrieve the data. - """ - assert delay_step is not None - if check.is_checking(): - jit_error(delay_step > self.max_length, self._check_delay, delay_step) - - if self.method == ROTATE_UPDATE: - i = share.load('i') - delay_idx = (i + delay_step - 1) % self.max_length - delay_idx = stop_gradient(delay_idx) - - elif self.method == CONCAT_UPDATE: - delay_idx = delay_step - - else: - raise ValueError(f'Unknown updating method "{self.method}"') - - # the delay index - if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): - raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') - indices = (delay_idx,) + tuple(indices) - - # the delay data - return self.data[indices] - - def update( - self, - latest_value: Optional[Union[bm.Array, jax.Array]] = None - ) -> None: - """Update delay variable with the new data. - """ - if self.data is not None: - # get the latest target value - if latest_value is None: - latest_value = self.target.value - - # update the delay data at the rotation index - if self.method == ROTATE_UPDATE: - i = share.load('i') - idx = bm.as_jax((i - 1) % self.max_length) - self.data[idx] = latest_value - - # update the delay data at the first position - elif self.method == CONCAT_UPDATE: - if self.max_length > 1: - self.data.value = bm.vstack([latest_value, self.data[1:]]) - else: - self.data[0] = latest_value - - def reset_state(self, batch_size: int = None): - """Reset the delay data. - """ - # initialize delay data - if self.data is not None: - self._init_data(self.max_length, batch_size) - - def _init_data(self, length: int, batch_size: int = None): - if batch_size is not None: - if self.target.batch_size != batch_size: - raise ValueError(f'The batch sizes of delay variable and target variable differ ' - f'({self.target.batch_size} != {batch_size}). ' - 'Please reset the target variable first, because delay data ' - 'depends on the target variable. 
') - - if self.target.batch_axis is None: - batch_axis = None - else: - batch_axis = self.target.batch_axis + 1 - - f = jax.jit(jnp.zeros, - static_argnums=0, - static_argnames='dtype', - out_shardings=bm.sharding.get_sharding(self._data_sharding)) - data = f((length,) + self.target.shape, dtype=self.target.dtype) - self.data = bm.Variable(data, batch_axis=batch_axis) - # update delay data - if isinstance(self._init, (bm.Array, jax.Array, numbers.Number)): - self.data[:] = self._init - elif callable(self._init): - self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype) - - def _init_target(self): - target = variable_(partial(bm.zeros, dtype=self.data_type), - self.data_size, - self.mode, - axis_names=self._sharding, - batch_axis_name=bm.sharding.BATCH_AXIS) - self.target = bm.sharding.partition(bm.asarray(target), self._target_sharding) - - -class _DataDelay1(TargetDelay): - """Delay variable which has a fixed delay length. - - The data in this delay variable is arranged as:: - - delay = 0 [ data - delay = 1 data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - - Args: - size: int, sequence of int. The delay target size. - sharding: sequence of str. The name for each axis. - time: optional, int, float. The delay time. Default is None. - dtype: type. The data type. - init: Any. The delay data. It can be a Python number, like float, int, boolean values. - It can also be arrays. Or a callable function or instance of ``Connector``. - Note that ``initial_delay_data`` should be arranged as the following way:: - - delay = 1 [ data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - entries: optional, dict. The delay access entries. - name: str. The delay name. - method: str. The method used for updating delay. Default None. - mode: Mode. The computing mode. Default None. - - """ - - not_desc_params = ('time', 'entries') - - def __init__( - self, - - # delay info - size: Union[int, Sequence[int]], - sharding: Optional[Sequence[str]] = None, - dtype: Optional[type] = None, + # delay target + target: bm.Variable, + target_init: Callable, # delay time time: Optional[Union[int, float]] = None, @@ -1024,28 +718,8 @@ def __init__( name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): - size = tools.to_size(size) - mode = mode if mode is not None else bm.get_mode() - if sharding is not None: - assert len(size) == len(sharding) - if isinstance(mode, bm.BatchingMode): - batch_axis = 0 - size = (mode.batch_size,) + size - else: - batch_axis = None - - target = bm.Variable(bm.zeros(size, dtype=dtype), batch_axis=batch_axis) - if init is None: - pass - elif isinstance(init, (bm.Array, jax.Array, numbers.Number)): - target[:] = self._init - elif callable(self._init): - target[:] = self._init(size, dtype=dtype) - else: - raise ValueError - + self.target_init = target_init super().__init__(target=target, - sharding=sharding, time=time, init=init, entries=entries, @@ -1053,25 +727,18 @@ def __init__( name=name, mode=mode) + def reset_state(self, batch_size: int = None): + """Reset the delay data. + """ + self.target.value = variable_(self.target_init, self.target.size_without_batch, batch_size) + if self.data is not None: + self._init_data(self.max_length, batch_size) + def update( self, latest_value: Union[bm.Array, jax.Array] ) -> None: """Update delay variable with the new data. 
""" - # get the latest target value self.target.value = latest_value - - if self.data is not None: - # update the delay data at the rotation index - if self.method == ROTATE_UPDATE: - i = share.load('i') - idx = bm.as_jax((i - 1) % (self.max_length + 1)) - self.data[idx] = latest_value - - # update the delay data at the first position - elif self.method == CONCAT_UPDATE: - if self.max_length >= 2: - self.data.value = bm.vstack([latest_value, self.data[1:]]) - else: - self.data[0] = latest_value + super().update(latest_value) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 51312e9c0..39636562a 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- -from typing import Dict, Optional, Union, Callable, Sequence +from typing import Dict, Optional, Union, Callable +import jax import jax.numpy as jnp from brainpy import math as bm @@ -16,13 +17,14 @@ from .base import Layer __all__ = [ - 'Dense', - 'Linear', + 'Dense', 'Linear', 'Identity', - 'AllToAll', 'OneToOne', - 'MaskedDense', + 'MaskedLinear', + 'CSRLinear', 'EventCSRLinear', + 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear', + 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear', ] @@ -83,11 +85,13 @@ def __init__( is_initializer(b_initializer, 'bias_initializer', allow_none=True) # parameter initialization - self.W = parameter(self.weight_initializer, (num_in, self.num_out)) - self.b = parameter(self.bias_initializer, (self.num_out,)) + W = parameter(self.weight_initializer, (num_in, self.num_out)) + b = parameter(self.bias_initializer, (self.num_out,)) if isinstance(self.mode, bm.TrainingMode): - self.W = bm.TrainVar(self.W) - self.b = None if (self.b is None) else bm.TrainVar(self.b) + W = bm.TrainVar(W) + b = None if (b is None) else bm.TrainVar(b) + self.W = W + self.b = b # fitting parameters self.online_fit_by = None @@ -214,22 +218,6 @@ def update(self, x): return x -class CSRLinear(Layer): - pass - - -class CSCLinear(Layer): - pass - - -class BSRLinear(Layer): - pass - - -class MatLinear(Layer): - pass - - class AllToAll(Layer): """Synaptic matrix multiplication with All2All connections. @@ -250,7 +238,6 @@ def __init__( weight: Union[float, ArrayType, Callable], sharding: Optional[Sharding] = None, include_self: bool = True, - mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): @@ -261,17 +248,18 @@ def __init__( self.include_self = include_self self.sharding = sharding - self.weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding) + weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding) if isinstance(self.mode, bm.TrainingMode): - self.weight = bm.TrainVar(self.weight) + weight = bm.TrainVar(weight) + self.weight = weight def update(self, pre_val): if bm.ndim(self.weight) == 0: # weight is a scalar if isinstance(self.mode, bm.BatchingMode): - assert pre_val.ndim == 2 + assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.' post_val = bm.sum(pre_val, keepdims=True, axis=1) else: - assert pre_val.ndim == 1 + assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.' post_val = bm.sum(pre_val) if not self.include_self: if self.num_pre == self.num_post: @@ -285,6 +273,7 @@ def update(self, pre_val): post_val = self.weight * post_val else: # weight is a matrix + assert self.weight.ndim == 2, '"weight" must be a 2D matrix.' 
if not self.include_self: post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False) else: @@ -317,30 +306,16 @@ def __init__( self.num = num self.sharding = sharding - self.weight = init.parameter(weight, (self.num,), sharding=sharding) + weight = init.parameter(weight, (self.num,), sharding=sharding) if isinstance(self.mode, bm.TrainingMode): - self.weight = bm.TrainVar(self.weight) + weight = bm.TrainVar(weight) + self.weight = weight def update(self, pre_val): return pre_val * self.weight -class _SynMatMul(Layer): - def __init__( - self, - conn: connect.TwoEndConnector, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding - - -class MaskedDense(_SynMatMul): +class MaskedLinear(Layer): r"""Synaptic matrix multiplication with dense computation. It performs the computation of: @@ -368,7 +343,11 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): - super().__init__(name=name, mode=mode, conn=mask) + super().__init__(name=name, mode=mode) + + assert isinstance(mask, connect.TwoEndConnector) + self.conn = mask + self.sharding = sharding # weight weight = init.parameter(weight, (mask.pre_num, mask.post_num), sharding=sharding) @@ -383,7 +362,7 @@ def update(self, x): return x @ (self.weight * self.mask) -class CsrMM(_SynMatMul): +class CSRLinear(Layer): r"""Synaptic matrix multiplication with CSR sparse computation. It performs the computation of: @@ -410,22 +389,49 @@ def __init__( sharding: Optional[Sharding] = None, mode: Optional[bm.Mode] = None, name: Optional[str] = None, + method: str = 'cusparse', + transpose: bool = True, ): - super().__init__(name=name, mode=mode, conn=conn) + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + assert sharding is None, 'Currently this model does not support sharding.' + self.conn = conn + self.sharding = sharding + self.method = method + self.transpose = transpose # connection self.indices, self.indptr = self.conn.require('csr') # weight - self.weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding) + weight = init.parameter(weight, (self.indices.size,)) if isinstance(self.mode, bm.TrainingMode): - self.weight = bm.TrainVar(self.weight) + weight = bm.TrainVar(weight) + self.weight = weight def update(self, x): - raise NotImplementedError + if x.ndim == 1: + return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose, + method=self.method) + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose, + method=self.method) -class CscMM(_SynMatMul): +class CSCLinear(Layer): r"""Synaptic matrix multiplication with CSC sparse computation. 
It performs the computation of: @@ -453,14 +459,79 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): - super().__init__(name=name, mode=mode, conn=conn) + super().__init__(name=name, mode=mode) + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding -class EventCsrMM(_SynMatMul): - pass +class EventCSRLinear(Layer): + r"""Synaptic matrix multiplication with event CSR sparse computation. -class BcsrMM(_SynMatMul): + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + assert sharding is None, 'Currently this model does not support sharding.' + self.conn = conn + self.sharding = sharding + self.transpose = transpose + + # connection + self.indices, self.indptr = self.conn.require('csr') + + # weight + weight = init.parameter(weight, (self.indices.size,)) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose) + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose) + + +class BcsrMM(Layer): r"""Synaptic matrix multiplication with BCSR sparse computation. It performs the computation of: @@ -488,10 +559,14 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): - super().__init__(name=name, mode=mode, conn=conn) + super().__init__(name=name, mode=mode) + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding -class BcscMM(_SynMatMul): + +class BcscMM(Layer): r"""Synaptic matrix multiplication with BCSC sparse computation. It performs the computation of: @@ -519,28 +594,486 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): - super().__init__(name=name, mode=mode, conn=conn) + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding + + +class JitFPHomoLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. 
+ Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: int, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPUniformLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: int, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError -class JitProbHomoMM(_SynMatMul): - pass + def _batch_mv(self, x): + return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) -class JitProbUniformMM(_SynMatMul): - pass +class JitFPNormalLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + It performs the computation of: -class JitProbNormalMM(_SynMatMul): - pass + .. math:: + y = x @ M -class EventJitProbHomMM(_SynMatMul): - pass + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_mu: float. The center of the normal distribution. + w_sigma: float. The standard variance of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ -class EventJitProbUniformMM(_SynMatMul): - pass + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: int, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPHomoLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: int, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPUniformLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: int, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPNormalLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_mu: float. The center of the normal distribution. + w_sigma: float. The standard variance of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: int, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError -class EventJitProbNormalMM(_SynMatMul): - pass + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py new file mode 100644 index 000000000..337536fd2 --- /dev/null +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -0,0 +1,172 @@ +import brainpy as bp +from absl.testing import parameterized + +import brainpy.math as bm + + +class TestLinear(parameterized.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bm.random.seed() + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + num_out=[20, 10, 5] + ) + def test_Dense1(self, size, num_out): + f = bp.dnn.Linear(10, num_out) + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size[:-1] + (num_out,)) + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + ) + def test_Identity(self, size): + f = bp.dnn.Identity() + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size) + + def test_AllToAll1(self): + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((8, 10)) + y = f(x) + expected = bm.sum(x, axis=1, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((10,)) + y = f(x) + expected = bm.sum(x, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + def test_OneToOne(self): + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((8, 10)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((10,)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + @parameterized.product( + conn=[ + # bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_MaskedLinear(self, conn): + bm.random.DEFAULT.seed(123) + f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + + 
@parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_CSRLinear(self, conn): + f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPHomoLinear(self, prob, weight, shape): + f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): + f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPHomoLinear(self, prob, weight, shape): + f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): + f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + + diff --git a/brainpy/_src/dyn/projections.py b/brainpy/_src/dyn/projections.py index accfe8a29..db84f3668 100644 --- a/brainpy/_src/dyn/projections.py +++ b/brainpy/_src/dyn/projections.py @@ -1,7 +1,7 @@ from typing import Optional, Callable, Union from brainpy import math as bm -from brainpy._src.delay import Delay, TargetDelay +from brainpy._src.delay import Delay, VariableDelay, DataDelay from brainpy._src.dyn.base import NeuDyn, SynOut from brainpy._src.dynsys import DynamicalSystemNS, DynamicalSystem from brainpy._src.mixin 
import DelayedInit, ReturnInfo, ProjAutoDelay @@ -42,7 +42,7 @@ def update(self, *args, **kwargs): def _init_delay(info: Union[bm.Variable, ReturnInfo]) -> Delay: if isinstance(info, bm.Variable): - target = info + return VariableDelay(info) elif isinstance(info, ReturnInfo): if isinstance(info.batch_or_mode, int): size = (info.batch_or_mode,) + tuple(info.size) @@ -59,9 +59,10 @@ def _init_delay(info: Union[bm.Variable, ReturnInfo]) -> Delay: target = bm.Variable(info.init(size), batch_axis=batch_axis, axis_names=info.axis_names) + return DataDelay(target, target_init=info.init) else: raise TypeError - return TargetDelay(target) + class ProjAlignPre(SynProj): diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index 7164f0457..1eb5bb3cd 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -168,28 +168,31 @@ def __call__(self, *args, **kwargs): if share is None: from brainpy._src.context import share - if self._pass_shared_args: - if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'): - if len(args) and isinstance(args[0], dict): + try: + if self._pass_shared_args: + if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'): + if len(args) and isinstance(args[0], dict): + share.save(**args[0]) + return self.update(*args[1:], **kwargs) + else: + return self.update(*args, **kwargs) + else: + if len(args) and isinstance(args[0], dict): + return self.update(*args, **kwargs) + else: + # If first argument is not shared argument, + # we should get the shared arguments from the global context. + # However, users should set and update shared arguments + # in the global context when using this mode. + return self.update(share.get_shargs(), *args, **kwargs) + else: + if len(args) and isinstance(args[0], dict): # it may be shared arguments share.save(**args[0]) return self.update(*args[1:], **kwargs) else: return self.update(*args, **kwargs) - else: - if len(args) and isinstance(args[0], dict): - return self.update(*args, **kwargs) - else: - # If first argument is not shared argument, - # we should get the shared arguments from the global context. - # However, users should set and update shared arguments - # in the global context when using this mode. 
- return self.update(share.get_shargs(), *args, **kwargs) - else: - if len(args) and isinstance(args[0], dict): # it may be shared arguments - share.save(**args[0]) - return self.update(*args[1:], **kwargs) - else: - return self.update(*args, **kwargs) + except Exception as e: + raise RuntimeError(f'Error occurs when running {self.name}: {self}') from e def register_delay( self, diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index 3951d37e8..7a10a8227 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -218,6 +218,14 @@ def __init__( axis_names = tuple(axis_names) self.axis_names = axis_names + @property + def size_without_batch(self): + if self.batch_axis is None: + return self.size + else: + sizes = self.size + return sizes[:self.batch_size] + sizes[self.batch_axis + 1:] + @property def batch_axis(self) -> Optional[int]: return self._batch_axis diff --git a/brainpy/_src/synapses/abstract_models.py b/brainpy/_src/synapses/abstract_models.py index 55ee0980d..4f82392db 100644 --- a/brainpy/_src/synapses/abstract_models.py +++ b/brainpy/_src/synapses/abstract_models.py @@ -343,7 +343,7 @@ def update(self, tdi, pre_spike=None): post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': - f = lambda s: bm.event_csr_matvec( + f = lambda s: bm.event.csrmv( self.g_max, self.conn_mask[0], self.conn_mask[1], s, shape=(self.pre.num, self.post.num), transpose=True @@ -548,7 +548,7 @@ def update(self, tdi, pre_spike=None): post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': - f = lambda s: bm.cusparse_csr_matvec( + f = lambda s: bm.sparse.csrmv( self.g_max, self.conn_mask[0], self.conn_mask[1], s, shape=(self.pre.num, self.post.num), transpose=True @@ -893,7 +893,7 @@ def update(self, tdi, pre_spike=None): post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': - f = lambda s: bm.event_csr_matvec( + f = lambda s: bm.event.csrmv( self.g_max, self.conn_mask[0], self.conn_mask[1], s, shape=(self.pre.num, self.post.num), transpose=True diff --git a/brainpy/_src/synapses/biological_models.py b/brainpy/_src/synapses/biological_models.py index c4b126c68..9bf9c1c03 100644 --- a/brainpy/_src/synapses/biological_models.py +++ b/brainpy/_src/synapses/biological_models.py @@ -229,7 +229,7 @@ def update(self, tdi, pre_spike=None): post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': - f = lambda s: bm.cusparse_csr_matvec( + f = lambda s: bm.sparse.csrmv( self.g_max, self.conn_mask[0], self.conn_mask[1], s, shape=(self.pre.num, self.post.num), transpose=True @@ -573,7 +573,7 @@ def update(self, tdi, pre_spike=None): post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': - f = lambda s: bm.cusparse_csr_matvec( + f = lambda s: bm.sparse.csrmv( self.g_max,self.conn_mask[0], self.conn_mask[1], s, shape=(self.pre.num, self.post.num), transpose=True diff --git a/brainpy/dnn.py b/brainpy/dnn.py index 865aa7bd0..8c0019a91 100644 --- a/brainpy/dnn.py +++ b/brainpy/dnn.py @@ -5,6 +5,7 @@ Layer as Layer, ) + from brainpy._src.dnn.conv import ( Conv1d as Conv1d, Conv2d as Conv2d, @@ -22,22 +23,32 @@ Dropout as Dropout, ) + from brainpy._src.dnn.function import ( Activation as Activation, Flatten as Flatten, FunAsLayer as FunAsLayer, ) + from brainpy._src.dnn.linear import ( 
Dense as Dense, Linear as Linear, Identity as Identity, - AllToAll as AllToAll, OneToOne as OneToOne, - MaskedDense as MaskedDense, + MaskedLinear as MaskedLinear, + CSRLinear as CSRLinear, + EventCSRLinear as EventCSRLinear, + JitFPHomoLinear as JitFPHomoLinear, + JitFPUniformLinear as JitFPUniformLinear, + JitFPNormalLinear as JitFPNormalLinear, + EventJitFPHomoLinear as EventJitFPHomoLinear, + EventJitFPNormalLinear as EventJitFPNormalLinear, + EventJitFPUniformLinear as EventJitFPUniformLinear, ) + from brainpy._src.dnn.normalization import ( BatchNorm1d as BatchNorm1d, BatchNorm2d as BatchNorm2d, @@ -50,10 +61,12 @@ InstanceNorm as InstanceNorm, ) + from brainpy._src.dnn.nvar import ( NVAR as NVAR, ) + from brainpy._src.dnn.pooling import ( MaxPool as MaxPool, MaxPool1d as MaxPool1d, @@ -75,10 +88,12 @@ AdaptiveMaxPool3d as AdaptiveMaxPool3d, ) + from brainpy._src.dnn.reservoir import ( Reservoir as Reservoir, ) + from brainpy._src.dnn.rnncells import ( RNNCell as RNNCell, GRUCell as GRUCell, @@ -88,6 +103,7 @@ Conv3dLSTMCell as Conv3dLSTMCell, ) + from brainpy._src.dnn.interoperation_flax import ( FromFlax, ToFlaxRNNCell, ToFlax, diff --git a/examples/dynamics_simulation/COBA-v2.py b/examples/dynamics_simulation/COBA-v2.py new file mode 100644 index 000000000..0a9077e66 --- /dev/null +++ b/examples/dynamics_simulation/COBA-v2.py @@ -0,0 +1,169 @@ +import brainpy as bp + +neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + + +class EICOBA_PreAlign(bp.DynamicalSystemNS): + def __init__(self, num_exc, num_inh, inp=20.): + super().__init__() + + self.inp = inp + self.E = bp.dyn.LifRefLTC(num_exc, **neu_pars) + self.I = bp.dyn.LifRefLTC(num_inh, **neu_pars) + + self.E2I = bp.dyn.ProjAlignPre( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=None, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.I.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.I, + ) + self.E2E = bp.dyn.ProjAlignPre( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=None, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.E.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.E, + ) + self.I2E = bp.dyn.ProjAlignPre( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=None, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.E.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.ProjAlignPre( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=0., + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.I.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + + +class EICOBA_PostAlign(bp.DynamicalSystemNS): + def __init__(self, num_exc, num_inh, inp=20.): + super().__init__() + self.inp = inp + + self.E = bp.dyn.LifRefLTC(num_exc, **neu_pars) + self.I = bp.dyn.LifRefLTC(num_inh, **neu_pars) + + self.E2E = bp.dyn.ProjAlignPost( + pre=self.E, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.E.num), 0.6), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E, + ) + self.E2I = bp.dyn.ProjAlignPost( + pre=self.E, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.I.num), 0.6), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), + 
out=bp.dyn.COBA.desc(E=0.), + post=self.I, + ) + self.I2E = bp.dyn.ProjAlignPost( + pre=self.I, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.E.num), 6.7), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.ProjAlignPost( + pre=self.I, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.I.num), 6.7), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + + +class EINet(bp.Network): + def __init__(self, scale=1.0, method='exp_auto'): + # network size + num_exc = int(3200 * scale) + num_inh = int(800 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + E = bp.neurons.LIF(num_exc, **pars, method=method) + I = bp.neurons.LIF(num_inh, **pars, method=method) + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02), + g_max=we, tau=5., method=method, + output=bp.synouts.COBA(E=0.)) + E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02), + g_max=we, tau=5., method=method, + output=bp.synouts.COBA(E=0.)) + I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02), + g_max=wi, tau=10., method=method, + output=bp.synouts.COBA(E=-80.)) + I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02), + g_max=wi, tau=10., method=method, + output=bp.synouts.COBA(E=-80.)) + + super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) + + +# num_device = 8 +# bm.set_host_device_count(num_device) +# bm.sharding.set(mesh_axes=(bp.dyn.PNEU_AXIS,), mesh_shape=(num_device, )) + +def run3(): + net = EICOBA_PreAlign(3200, 800) + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}) + print(runner.run(100., eval_time=True)) + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) + + +def run1(): + net = EICOBA_PostAlign(3200, 800) + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}) + print(runner.run(100., eval_time=True)) + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) + + +def run2(): + net = EINet() + runner = bp.DSRunner(net, + monitors=['E.spike'], + inputs=[('E.input', 20.), ('I.input', 20.)]) + r = runner.run(100., eval_time=True) + print(r) + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) + + +if __name__ == '__main__': + # run1() + # run2() + run3()
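
A short, self-contained sketch of the renamed and newly exported ``brainpy.dnn`` linear layers touched by this patch; the layer names and constructor signatures follow the tests and the COBA-v2 example above, while the concrete sizes, probabilities, and seeds are illustrative only:

import brainpy as bp
import brainpy.math as bm

# dense masked connectivity (formerly MaskedDense)
masked = bp.dnn.MaskedLinear(bp.conn.GaussianProb(0.1, pre=100, post=100),
                             weight=bp.init.XavierNormal())
y1 = masked(bm.random.random((100,)))            # -> (100,)

# event-driven CSR connectivity, as used by the COBA-v2 projections
event_csr = bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=100, post=100), 0.6)
y2 = event_csr(bm.random.random((100,)) < 0.1)   # boolean spikes -> (100,)

# just-in-time connectivity: sparse weights are regenerated from (prob, seed) on the fly
jit_homo = bp.dnn.JitFPHomoLinear(100, 200, prob=0.05, weight=0.01, seed=42)
y3 = jit_homo(bm.random.random((100,)))          # -> (200,)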