diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml
index df2ef61b0..2883600b3 100644
--- a/.github/workflows/CI-models.yml
+++ b/.github/workflows/CI-models.yml
@@ -27,12 +27,11 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
- pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
@@ -51,7 +50,7 @@ jobs:
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
-# uses: actions/setup-python@v4
+# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
@@ -75,12 +74,11 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
- pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
@@ -99,7 +97,7 @@ jobs:
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
-# uses: actions/setup-python@v4
+# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
@@ -124,13 +122,12 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install numpy>=1.21.0
- pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
python -m pip install -r requirements-dev.txt
python -m pip install tqdm brainpylib
pip uninstall brainpy -y
@@ -150,7 +147,7 @@ jobs:
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
-# uses: actions/setup-python@v4
+# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 01b5047ec..84aa028e3 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -29,14 +29,13 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
- pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
@@ -62,7 +61,7 @@ jobs:
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
-# uses: actions/setup-python@v4
+# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
@@ -96,14 +95,13 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
- pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
@@ -128,7 +126,7 @@ jobs:
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
-# uses: actions/setup-python@v4
+# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
@@ -158,23 +156,19 @@ jobs:
# strategy:
# fail-fast: false
# matrix:
-# python-version: ["3.8", "3.9", "3.10", "3.11"]
+# python-version: ["3.9", "3.10", "3.11"]
#
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
-# uses: actions/setup-python@v4
+# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# python -m pip install flake8 pytest
-# python -m pip install numpy>=1.21.0
-# python -m pip install "jaxlib==0.4.11" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
-# python -m pip install jax==0.4.11
# python -m pip install -r requirements-dev.txt
-# python -m pip install tqdm brainpylib
# pip uninstall brainpy -y
# python setup.py install
# - name: Lint with flake8
@@ -199,7 +193,7 @@ jobs:
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
-# uses: actions/setup-python@v4
+# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
diff --git a/.github/workflows/Publish.yml b/.github/workflows/Publish.yml
index b00b1f1b5..fd377770e 100644
--- a/.github/workflows/Publish.yml
+++ b/.github/workflows/Publish.yml
@@ -10,7 +10,7 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
- - run: python setup.py bdist_wheel
+ - run: python setup.py bdist_wheel --python-tag=py3
- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1
with:
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 2d4189809..0c515d77a 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -18,7 +18,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- - uses: conda-incubator/setup-miniconda@v2
+ - uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
python-version: "3.10"
diff --git a/README.md b/README.md
index 716dbd900..6d2ee4bf4 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
-BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Numba](https://github.com/numba/numba), and other JIT compilers). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.
+BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation, based on Just-In-Time (JIT) compilation and built on top of [JAX](https://github.com/google/jax), [Taichi](https://github.com/taichi-dev/taichi), [Numba](https://github.com/numba/numba), and others. It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, and **analysis**.
- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest
- **Source**: https://github.com/brainpy/BrainPy
@@ -77,7 +77,9 @@ We provide a Binder environment for BrainPy. You can use the following button to
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
+- [《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling)
- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course)
+- [第二届神经计算建模与编程培训班 (Second Training Course on Neural Modeling and Programming)](https://github.com/brainpy/2nd-neural-modeling-and-programming-course)
## Citing
@@ -102,4 +104,6 @@ We also welcome your contributions
- [ ] pipeline parallelization on multiple devices for sparse spiking network models
- [ ] multi-compartment modeling
- [ ] measurements, analysis, and visualization methods for large-scale spiking data
+- [ ] online learning methods for large-scale spiking network models
+- [ ] classical plasticity rules for large-scale spiking network models
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 371ed6b27..a3a1de694 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-__version__ = "2.4.6"
+
+__version__ = "2.5.0"
# fundamental supporting modules
from brainpy import errors, check, tools
@@ -75,9 +76,10 @@
)
NeuGroup = NeuGroupNS = dyn.NeuDyn
-# shared parameters
+# common tools
from brainpy._src.context import (share as share)
-from brainpy._src.helpers import (reset_state as reset_state,
+from brainpy._src.helpers import (reset_level as reset_level,
+ reset_state as reset_state,
save_state as save_state,
load_state as load_state,
clear_input as clear_input)
diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py
index 17edcff31..d04c3aa2e 100644
--- a/brainpy/_add_deprecations.py
+++ b/brainpy/_add_deprecations.py
@@ -88,6 +88,16 @@
# neurons
'NeuGroup': ('brainpy.dyn.NeuGroup', 'brainpy.dyn.NeuDyn', NeuDyn),
+ # projections
+ 'ProjAlignPostMg1': ('brainpy.dyn.ProjAlignPostMg1', 'brainpy.dyn.HalfProjAlignPostMg', dyn.HalfProjAlignPostMg),
+ 'ProjAlignPostMg2': ('brainpy.dyn.ProjAlignPostMg2', 'brainpy.dyn.FullProjAlignPostMg', dyn.FullProjAlignPostMg),
+ 'ProjAlignPost1': ('brainpy.dyn.ProjAlignPost1', 'brainpy.dyn.HalfProjAlignPost', dyn.HalfProjAlignPost),
+ 'ProjAlignPost2': ('brainpy.dyn.ProjAlignPost2', 'brainpy.dyn.FullProjAlignPost', dyn.FullProjAlignPost),
+ 'ProjAlignPreMg1': ('brainpy.dyn.ProjAlignPreMg1', 'brainpy.dyn.FullProjAlignPreSDMg', dyn.FullProjAlignPreSDMg),
+ 'ProjAlignPreMg2': ('brainpy.dyn.ProjAlignPreMg2', 'brainpy.dyn.FullProjAlignPreDSMg', dyn.FullProjAlignPreDSMg),
+ 'ProjAlignPre1': ('brainpy.dyn.ProjAlignPre1', 'brainpy.dyn.FullProjAlignPreSD', dyn.FullProjAlignPreSD),
+ 'ProjAlignPre2': ('brainpy.dyn.ProjAlignPre2', 'brainpy.dyn.FullProjAlignPreDS', dyn.FullProjAlignPreDS),
+
# synapses
'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
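
A minimal sketch of what the new projection aliases buy downstream code, assuming they are served by `brainpy.dyn`'s module-level `__getattr__` like the existing entries in this table (names taken from the tuples above):

```python
# Sketch: old projection names still resolve, but emit a DeprecationWarning
# and return the renamed class.
import warnings
import brainpy as bp

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    old_cls = bp.dyn.ProjAlignPostMg2           # deprecated alias
print(old_cls is bp.dyn.FullProjAlignPostMg)     # expected: True
print(any(issubclass(w.category, DeprecationWarning) for w in caught))
```
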
diff --git a/brainpy/_src/checkpoints/serialization.py b/brainpy/_src/checkpoints/serialization.py
index d12f5a1c8..a19a2b68e 100644
--- a/brainpy/_src/checkpoints/serialization.py
+++ b/brainpy/_src/checkpoints/serialization.py
@@ -19,21 +19,19 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import jax
-import msgpack
import numpy as np
+from jax import monitoring
from jax import process_index
from jax.experimental.multihost_utils import sync_global_devices
-
try:
- from jax import monitoring
-except (ModuleNotFoundError, ImportError):
- monitoring = None
+ from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa
+except Exception:
+ get_tensorstore_spec = GlobalAsyncCheckpointManager = None
try:
- from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa
-except (ModuleNotFoundError, ImportError):
- get_tensorstore_spec = None
- GlobalAsyncCheckpointManager = None
+ import msgpack
+except ModuleNotFoundError:
+ msgpack = None
from brainpy._src.math.ndarray import Array
from brainpy.errors import (AlreadyExistsError,
@@ -116,6 +114,12 @@ def _record_path(name):
_error_context.path.pop()
+def check_msgpack():
+ if msgpack is None:
+    raise ModuleNotFoundError('\nbrainpy.checkpoints needs the "msgpack" package. Please install msgpack via:\n'
+ '> pip install msgpack')
+
+
def current_path():
"""Current state_dict path during deserialization for error messages."""
return '/'.join(_error_context.path)
@@ -1126,6 +1130,7 @@ def save(
out: str
Filename of saved checkpoint.
"""
+ check_msgpack()
start_time = time.time()
# Make sure all saves are finished before the logic of checking and removing
# outdated checkpoints happens.
@@ -1257,6 +1262,7 @@ def save_pytree(
out: str
Filename of saved checkpoint.
"""
+ check_msgpack()
if verbose:
print(f'Saving checkpoint into {filename}')
start_time = time.time()
@@ -1344,6 +1350,7 @@ def multiprocess_save(
out: str
Filename of saved checkpoint.
"""
+ check_msgpack()
start_time = time.time()
# Make sure all saves are finished before the logic of checking and removing
# outdated checkpoints happens.
@@ -1488,6 +1495,7 @@ def load(
returned. This is to match the behavior of the case where a directory path
is specified but the directory has not yet been created.
"""
+ check_msgpack()
start_time = time.time()
ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str
@@ -1582,6 +1590,7 @@ def load_pytree(
returned. This is to match the behavior of the case where a directory path
is specified but the directory has not yet been created.
"""
+ check_msgpack()
start_time = time.time()
if not os.path.exists(filename):
raise ValueError(f'Checkpoint not found: {filename}')
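
Since `msgpack` becomes a lazily checked optional dependency here, a hedged usage sketch (assuming the public entry points remain `brainpy.checkpoints.save_pytree` / `load_pytree` as documented above):

```python
# Sketch: check_msgpack() runs at call time, so importing brainpy no longer
# requires msgpack; only the checkpoint APIs do.
import brainpy as bp
import brainpy.math as bm

state = {'w': bm.zeros(4), 'step': 10}
try:
    bp.checkpoints.save_pytree('model.bp', state)     # guarded by check_msgpack()
    restored = bp.checkpoints.load_pytree('model.bp')
except ModuleNotFoundError as err:
    print(err)   # carries the '> pip install msgpack' hint added above
```
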
diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py
index 33456c02f..3bba20a79 100644
--- a/brainpy/_src/dependency_check.py
+++ b/brainpy/_src/dependency_check.py
@@ -1,38 +1,41 @@
+import os
+import sys
from jax.lib import xla_client
-
__all__ = [
'import_taichi',
'import_brainpylib_cpu_ops',
'import_brainpylib_gpu_ops',
]
-
-_minimal_brainpylib_version = '0.1.10'
+_minimal_brainpylib_version = '0.2.6'
_minimal_taichi_version = (1, 7, 0)
taichi = None
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None
+taichi_install_info = (f'We need taichi=={".".join(map(str, _minimal_taichi_version))}. '
+                       'Currently you can install it through:\n\n'
+ '> pip install taichi==1.7.0')
+os.environ["TI_LOG_LEVEL"] = "error"
+
def import_taichi():
global taichi
if taichi is None:
- try:
- import taichi as taichi # noqa
- except ModuleNotFoundError:
- raise ModuleNotFoundError(
- 'Taichi is needed. Please install taichi through:\n\n'
- '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
- )
-
- if taichi.__version__ < _minimal_taichi_version:
- raise RuntimeError(
- f'We need taichi>={_minimal_taichi_version}. '
- f'Currently you can install taichi>={_minimal_taichi_version} through taichi-nightly:\n\n'
- '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
- )
+ with open(os.devnull, 'w') as devnull:
+ old_stdout = sys.stdout
+ sys.stdout = devnull
+ try:
+ import taichi as taichi # noqa
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(taichi_install_info)
+ finally:
+ sys.stdout = old_stdout
+
+ if taichi.__version__ != _minimal_taichi_version:
+ raise RuntimeError(taichi_install_info)
return taichi
@@ -82,6 +85,3 @@ def import_brainpylib_gpu_ops():
'See https://brainpy.readthedocs.io for installation instructions.')
return brainpylib_gpu_ops
-
-
-
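
A short sketch of how callers see the stricter, quieter taichi import (the import path is the module modified above):

```python
# Sketch: import_taichi() silences taichi's banner on stdout and pins the
# version; both failure paths raise with taichi_install_info.
from brainpy._src.dependency_check import import_taichi

try:
    ti = import_taichi()
except (ModuleNotFoundError, RuntimeError) as err:
    print(err)   # points the user to `pip install taichi==1.7.0`
```
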
diff --git a/brainpy/_src/deprecations.py b/brainpy/_src/deprecations.py
index b426aab8a..74a0103da 100644
--- a/brainpy/_src/deprecations.py
+++ b/brainpy/_src/deprecations.py
@@ -41,7 +41,6 @@ def f_input_or_monitor():
'''
-
def _deprecate(msg):
warnings.simplefilter('always', DeprecationWarning) # turn off filter
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
@@ -61,21 +60,25 @@ def new_func(*args, **kwargs):
return new_func
-def deprecation_getattr(module, deprecations):
- def getattr(name):
+def deprecation_getattr(module, deprecations, redirects=None, redirect_module=None):
+ redirects = redirects or {}
+
+ def get_attr(name):
if name in deprecations:
message, fn = deprecations[name]
if fn is None:
raise AttributeError(message)
_deprecate(message)
return fn
+ if name in redirects:
+ return getattr(redirect_module, name)
raise AttributeError(f"module {module!r} has no attribute {name!r}")
- return getattr
+ return get_attr
def deprecation_getattr2(module, deprecations):
- def getattr(name):
+ def get_attr(name):
if name in deprecations:
old_name, new_name, fn = deprecations[name]
message = f"{old_name} is deprecated. "
@@ -87,4 +90,4 @@ def getattr(name):
return fn
raise AttributeError(f"module {module!r} has no attribute {name!r}")
- return getattr
+ return get_attr
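
To illustrate the new `redirects` path, a hedged sketch of how a submodule might wire it up (the attribute names and the `math_mod` stand-in are illustrative, not taken from this diff):

```python
# Sketch: a PEP 562 module-level __getattr__ that warns for deprecated names
# and transparently forwards redirected names to another module.
import brainpy.math as math_mod
from brainpy._src.deprecations import deprecation_getattr

_deprecations = {
    'old_api': ('old_api has been removed; use brainpy.math instead.', None),
}
__getattr__ = deprecation_getattr(
    __name__, _deprecations,
    redirects=('ones', 'zeros'),     # looked up on redirect_module, no warning
    redirect_module=math_mod,
)
```
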
diff --git a/brainpy/_src/dnn/activations.py b/brainpy/_src/dnn/activations.py
index 1073c7ec8..84b7e4009 100644
--- a/brainpy/_src/dnn/activations.py
+++ b/brainpy/_src/dnn/activations.py
@@ -840,10 +840,10 @@ class Softplus(Layer):
>>> output = m(input)
"""
__constants__ = ['beta', 'threshold']
- beta: int
- threshold: int
+ beta: float
+ threshold: float
- def __init__(self, beta: int = 1, threshold: int = 20) -> None:
+ def __init__(self, beta: float = 1, threshold: float = 20.) -> None:
super().__init__()
self.beta = beta
self.threshold = threshold
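
A small sketch of the relaxed `Softplus` signature, assuming the layer stays exported as `brainpy.dnn.Softplus`:

```python
# Sketch: beta/threshold now accept (and are annotated as) floats.
import brainpy as bp
import brainpy.math as bm

m = bp.dnn.Softplus(beta=0.5, threshold=15.)
y = m(bm.random.randn(4, 8))
print(y.shape)   # (4, 8)
```
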
diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py
index 75b6373c5..e4b6e25d2 100644
--- a/brainpy/_src/dnn/conv.py
+++ b/brainpy/_src/dnn/conv.py
@@ -4,10 +4,10 @@
from jax import lax
-from brainpy import math as bm, tools, check
+from brainpy import math as bm, tools
+from brainpy._src.dnn.base import Layer
from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.types import ArrayType
-from brainpy._src.dnn.base import Layer
__all__ = [
'Conv1d', 'Conv2d', 'Conv3d',
@@ -160,7 +160,7 @@ def update(self, x):
nonbatching = False
if x.ndim == self.num_spatial_dims + 1:
nonbatching = True
- x = x.unsqueeze(0)
+ x = bm.unsqueeze(x, 0)
w = self.w.value
if self.mask is not None:
try:
@@ -190,6 +190,9 @@ def __repr__(self):
class Conv1d(_GeneralConv):
"""One-dimensional convolution.
+  The input should be a 2d array with the shape of ``[H, C]``, or
+  a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size.
+
Parameters
----------
in_channels: int
@@ -282,6 +285,9 @@ def _check_input_dim(self, x):
class Conv2d(_GeneralConv):
"""Two-dimensional convolution.
+  The input should be a 3d array with the shape of ``[H, W, C]``, or
+ a 4d array with the shape of ``[B, H, W, C]``.
+
Parameters
----------
in_channels: int
@@ -375,6 +381,9 @@ def _check_input_dim(self, x):
class Conv3d(_GeneralConv):
"""Three-dimensional convolution.
+  The input should be a 4d array with the shape of ``[H, W, D, C]``, or
+  a 5d array with the shape of ``[B, H, W, D, C]``.
+
Parameters
----------
in_channels: int
@@ -488,9 +497,7 @@ def __init__(
mode: bm.Mode = None,
name: str = None,
):
- super(_GeneralConvTranspose, self).__init__(name=name, mode=mode)
-
- assert self.mode.is_parent_of(bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode)
+ super().__init__(name=name, mode=mode)
self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
@@ -586,22 +593,17 @@ def __init__(
"""Initializes the module.
Args:
- output_channels: Number of output channels.
- kernel_shape: The shape of the kernel. Either an integer or a sequence of
+ in_channels: Number of input channels.
+ out_channels: Number of output channels.
+ kernel_size: The shape of the kernel. Either an integer or a sequence of
length 1.
stride: Optional stride for the kernel. Either an integer or a sequence of
length 1. Defaults to 1.
- output_shape: Output shape of the spatial dimensions of a transpose
- convolution. Can be either an integer or an iterable of integers. If a
- `None` value is given, a default shape is automatically calculated.
padding: Optional padding algorithm. Either ``VALID`` or ``SAME``.
Defaults to ``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
- with_bias: Whether to add a bias. By default, true.
- w_init: Optional weight initialization. By default, truncated normal.
- b_init: Optional bias initialization. By default, zeros.
- data_format: The data format of the input. Either ``NWC`` or ``NCW``. By
- default, ``NWC``.
+ w_initializer: Optional weight initialization. By default, truncated normal.
+ b_initializer: Optional bias initialization. By default, zeros.
mask: Optional mask of the weights.
name: The name of the module.
"""
@@ -648,6 +650,7 @@ def __init__(
"""Initializes the module.
Args:
+ in_channels: Number of input channels.
out_channels: Number of output channels.
kernel_size: The shape of the kernel. Either an integer or a sequence of
length 2.
@@ -704,22 +707,17 @@ def __init__(
"""Initializes the module.
Args:
- output_channels: Number of output channels.
- kernel_shape: The shape of the kernel. Either an integer or a sequence of
+ in_channels: Number of input channels.
+ out_channels: Number of output channels.
+ kernel_size: The shape of the kernel. Either an integer or a sequence of
length 3.
stride: Optional stride for the kernel. Either an integer or a sequence of
length 3. Defaults to 1.
- output_shape: Output shape of the spatial dimensions of a transpose
- convolution. Can be either an integer or an iterable of integers. If a
- `None` value is given, a default shape is automatically calculated.
padding: Optional padding algorithm. Either ``VALID`` or ``SAME``.
Defaults to ``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
- with_bias: Whether to add a bias. By default, true.
- w_init: Optional weight initialization. By default, truncated normal.
- b_init: Optional bias initialization. By default, zeros.
- data_format: The data format of the input. Either ``NDHWC`` or ``NCDHW``.
- By default, ``NDHWC``.
+ w_initializer: Optional weight initialization. By default, truncated normal.
+ b_initializer: Optional bias initialization. By default, zeros.
mask: Optional mask of the weights.
name: The name of the module.
"""
diff --git a/brainpy/_src/dnn/function.py b/brainpy/_src/dnn/function.py
index 228dd7803..5f33552ed 100644
--- a/brainpy/_src/dnn/function.py
+++ b/brainpy/_src/dnn/function.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
-from typing import Callable
-from typing import Optional
+from typing import Callable, Optional, Sequence
import brainpy.math as bm
from brainpy._src.dnn.base import Layer
@@ -9,6 +8,7 @@
__all__ = [
'Activation',
'Flatten',
+ 'Unflatten',
'FunAsLayer',
]
@@ -43,28 +43,121 @@ def update(self, *args, **kwargs):
class Flatten(Layer):
- r"""Flattens a contiguous range of dims into 2D or 1D.
-
- Parameters:
- ----------
- name: str, Optional
- The name of the object
- mode: Mode
- Enable training this node or not. (default True)
+ r"""
+  Flattens a contiguous range of dims into a tensor. For use with :class:`~brainpy.Sequential`.
+
+ Shape:
+    - Input: :math:`(*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)`,
+ where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
+ number of dimensions including none.
+ - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
+
+ Args:
+ start_dim: first dim to flatten (default = 1).
+ end_dim: last dim to flatten (default = -1).
+ name: str, Optional. The name of the object.
+ mode: Mode. Enable training this node or not. (default True).
+
+ Examples::
+ >>> import brainpy.math as bm
+ >>> inp = bm.random.randn(32, 1, 5, 5)
+ >>> # With default parameters
+ >>> m = Flatten()
+ >>> output = m(inp)
+ >>> output.shape
+ (32, 25)
+ >>> # With non-default parameters
+ >>> m = Flatten(0, 2)
+ >>> output = m(inp)
+ >>> output.shape
+ (160, 5)
"""
def __init__(
self,
+ start_dim: int = 0,
+ end_dim: int = -1,
name: Optional[str] = None,
mode: bm.Mode = None,
):
super().__init__(name, mode)
+ self.start_dim = start_dim
+ self.end_dim = end_dim
+
def update(self, x):
- if isinstance(self.mode, bm.BatchingMode):
- return x.reshape((x.shape[0], -1))
+ if self.mode.is_child_of(bm.BatchingMode):
+ start_dim = (self.start_dim + 1) if self.start_dim >= 0 else (x.ndim + self.start_dim + 1)
else:
- return x.flatten()
+ start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim
+ return bm.flatten(x, start_dim, self.end_dim)
+
+ def __repr__(self) -> str:
+ return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})'
+
+
+class Unflatten(Layer):
+ r"""
+  Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~brainpy.Sequential`.
+
+  * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it
+    must be an `int`.
+
+  * :attr:`sizes` is the new shape of the unflattened dimension of the tensor, and it can be
+    a `tuple` of ints or a `list` of ints whose product equals the size of the input
+    at dimension :attr:`dim`.
+
+ Shape:
+ - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
+ dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
+ - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
+ :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
+
+ Args:
+ dim: int, Dimension to be unflattened.
+ sizes: Sequence of int. New shape of the unflattened dimension.
+
+ Examples:
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> input = bm.random.randn(2, 50)
+ >>> # With tuple of ints
+ >>> m = bp.Sequential(
+ >>> bp.dnn.Linear(50, 50),
+ >>> Unflatten(1, (2, 5, 5))
+ >>> )
+ >>> output = m(input)
+ >>> output.shape
+ (2, 2, 5, 5)
+    >>> # With a list of ints
+ >>> m = bp.Sequential(
+ >>> bp.dnn.Linear(50, 50),
+ >>> Unflatten(1, [2, 5, 5])
+ >>> )
+ >>> output = m(input)
+ >>> output.shape
+ (2, 2, 5, 5)
+ """
+
+ def __init__(self, dim: int, sizes: Sequence[int], mode: bm.Mode = None, name: str = None) -> None:
+ super().__init__(mode=mode, name=name)
+
+ self.dim = dim
+ self.sizes = sizes
+ if isinstance(sizes, (tuple, list)):
+ for idx, elem in enumerate(sizes):
+ if not isinstance(elem, int):
+          raise TypeError("sizes must be a tuple of ints, " +
+ "but found element of type {} at pos {}".format(type(elem).__name__, idx))
+ else:
+      raise TypeError("sizes must be a tuple or list, but found type {}".format(type(sizes).__name__))
+
+ def update(self, x):
+ dim = self.dim + 1 if self.mode.is_batch_mode() else self.dim
+ return bm.unflatten(x, dim, self.sizes)
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})'
class FunAsLayer(Layer):
diff --git a/brainpy/_src/dnn/interoperation_flax.py b/brainpy/_src/dnn/interoperation_flax.py
index 09f03ac13..9804ac3bb 100644
--- a/brainpy/_src/dnn/interoperation_flax.py
+++ b/brainpy/_src/dnn/interoperation_flax.py
@@ -86,7 +86,7 @@ def initialize_carry(self, rng, batch_dims, size=None, init_fn=None):
raise NotImplementedError
_state_vars = self.model.vars().unique().not_subset(bm.TrainVar)
- self.model.reset_state(batch_size=batch_dims)
+ self.model.reset(batch_size=batch_dims)
return [_state_vars.dict(), 0, 0.]
def setup(self):
diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index 09bf2958d..539214d3b 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -1,1280 +1,1423 @@
-# -*- coding: utf-8 -*-
-
-
-import numbers
-from typing import Dict, Optional, Union, Callable
-
-import jax
-import jax.numpy as jnp
-import numba
-import numpy as np
-
-from brainpy import math as bm
-from brainpy._src import connect, initialize as init
-from brainpy._src.context import share
-from brainpy._src.dnn.base import Layer
-from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
-from brainpy.check import is_initializer
-from brainpy.connect import csr2csc
-from brainpy.errors import MathError
-from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
-from brainpy.types import ArrayType, Sharding
-
-__all__ = [
- 'Dense', 'Linear',
- 'Identity',
- 'AllToAll',
- 'OneToOne',
- 'MaskedLinear',
- 'CSRLinear', 'EventCSRLinear',
- 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear',
- 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear',
-]
-
-
-class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline):
- r"""A linear transformation applied over the last dimension of the input.
-
- Mathematically, this node can be defined as:
-
- .. math::
-
- y = x \cdot weight + b
-
- Parameters
- ----------
- num_in: int
- The number of the input feature. A positive integer.
- num_out: int
- The number of the output features. A positive integer.
- W_initializer: optional, Initializer
- The weight initialization.
- b_initializer: optional, Initializer
- The bias initialization.
- mode: Mode
- Enable training this node or not. (default True)
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
- b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super(Dense, self).__init__(mode=mode, name=name)
-
- # shape
- self.num_in = num_in
- self.num_out = num_out
- if num_in < 0:
- raise ValueError(f'Received an invalid value for `num_out`, expected '
- f'a positive integer. Received: num_in={num_in}')
- if num_out < 0:
- raise ValueError(f'Received an invalid value for `num_out`, expected '
- f'a positive integer. Received: num_out={num_out}')
-
- # weight initializer
- self.W_initializer = W_initializer
- self.bias_initializer = b_initializer
- is_initializer(W_initializer, 'weight_initializer')
- is_initializer(b_initializer, 'bias_initializer', allow_none=True)
-
- # parameter initialization
- W = parameter(self.W_initializer, (num_in, self.num_out))
- b = parameter(self.bias_initializer, (self.num_out,))
- if isinstance(self.mode, bm.TrainingMode):
- W = bm.TrainVar(W)
- b = None if (b is None) else bm.TrainVar(b)
- self.W = W
- self.b = b
-
- # fitting parameters
- self.online_fit_by = None # support online training
- self.offline_fit_by = None # support offline training
- self.fit_record = dict()
-
- def __repr__(self):
- return (f'{self.__class__.__name__}(name={self.name}, '
- f'num_in={self.num_in}, '
- f'num_out={self.num_out}, '
- f'mode={self.mode})')
-
- def update(self, x):
- x = bm.as_jax(x)
- res = x @ self.W
- if self.b is not None:
- res += self.b
-
- # online fitting data
- if share.load('fit', False) and self.online_fit_by is not None:
- self.fit_record['input'] = x
- self.fit_record['output'] = res
-
- # offline fitting data
- if share.load('fit', False) and self.offline_fit_by is not None:
- self.fit_record['input'] = x
- self.fit_record['output'] = res
- return res
-
- def online_init(self):
- if self.b is None:
- num_input = self.num_in
- else:
- num_input = self.num_in + 1
- self.online_fit_by.register_target(feature_in=num_input, identifier=self.name)
-
- def online_fit(self,
- target: ArrayType,
- fit_record: Dict[str, ArrayType]):
- if not isinstance(target, (bm.ndarray, jnp.ndarray)):
- raise MathError(f'"target" must be a tensor, but got {type(target)}')
- x = fit_record['input']
- y = fit_record['output']
- if x.ndim != 2:
- raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, '
- f'num_feature), but we got {x.shape}')
- if target.ndim != 2:
- raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, '
- f'num_feature), but we got {target.shape}')
- if x.shape[0] != target.shape[0]:
- raise ValueError(f'Batch size of the input and target data should be '
- f'the same, while we got {x.shape[0]} != {target.shape[0]}.')
- if target.shape[1] != y.shape[1]:
- raise MathError(f'The output dimension of output and target data should be '
- f'the same, while we got {target.shape[1]} != {y.shape[1]}')
-
- # data
- if self.b is not None:
- x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1)
-
- # fitting
- dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name)
-
- # assign trained weights
- if self.b is None:
- self.W += dW
- else:
- db, dW = jnp.split(dW, [1])
- self.b += db[0]
- self.W += dW
-
- def offline_fit(self,
- target: ArrayType,
- fit_record: Dict[str, ArrayType]):
- """The offline training interface for the Dense node."""
- # data checking
- if not isinstance(target, (bm.ndarray, jnp.ndarray)):
- raise MathError(f'"targets" must be a tensor, but got {type(target)}')
- xs = fit_record['input']
- ys = fit_record['output']
- if xs.ndim != 3:
- raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, '
- f'num_feature), but we got {xs.shape}')
- if target.ndim != 3:
- raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, '
- f'num_feature), but we got {target.shape}')
- if ys.shape != target.shape:
- raise ValueError(f'The shapes of output and target data should be '
- f'the same, while we got {ys.shape} != {target.shape}.')
- if xs.shape[0] != target.shape[0]:
- raise ValueError(f'Batch size of the input and target data should be '
- f'the same, while we got {xs.shape[0]} != {target.shape[0]}.')
- if xs.shape[1] != target.shape[1]:
- raise MathError(f'The time dimension of input and target data should be '
- f'the same, while we got {xs.shape[1]} != {target.shape[1]}')
-
- # get input and target training data
- if self.b is not None:
- xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input)
-
- # solve weights by offline training methods
- weights = self.offline_fit_by(target, xs, ys)
-
- # assign trained weights
- if self.b is None:
- self.W.value = weights
- else:
- bias, Wff = jnp.split(weights, [1])
- self.W.value = Wff
- self.b.value = bias[0]
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if isinstance(self.W, float):
- raise ValueError(f'Cannot update the weight of a constant node.')
- if not isinstance(self.W, bm.Variable):
- self.tracing_variable('W', self.W, self.W.shape)
- if on_pre is not None:
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max)
- if on_post is not None:
- spike = on_post['spike']
- trace = on_post['trace']
- self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max)
-
-
-Linear = Dense
-
-
-class Identity(Layer):
- r"""A placeholder identity operator that is argument-insensitive.
- """
-
- def __init__(self, *args, **kwargs) -> None:
- super(Identity, self).__init__(*args, **kwargs)
-
- def update(self, x):
- return x
-
-
-@numba.njit(nogil=True, fastmath=True, parallel=False)
-def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w):
- out_w[:] = weight
- for i in numba.prange(spike.shape[0]):
- if spike[i]:
- out_w[i] = np.clip(out_w[i] + trace, w_min, w_max)
-
-
-dense_on_pre_prim = bm.XLACustomOp(_cpu_dense_on_pre)
-
-
-def dense_on_pre(weight, spike, trace, w_min, w_max):
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- return dense_on_pre_prim(weight, spike, trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
-
-
-@numba.njit(nogil=True, fastmath=True, parallel=False)
-def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
- out_w[:] = weight
- for i in numba.prange(spike.shape[0]):
- if spike[i]:
- out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max)
-
-
-dense_on_post_prim = bm.XLACustomOp(_cpu_dense_on_post)
-
-
-def dense_on_post(weight, spike, trace, w_min, w_max):
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- return dense_on_post_prim(weight, spike, trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
-
-
-class AllToAll(Layer, SupportSTDP):
- """Synaptic matrix multiplication with All2All connections.
-
- Args:
- num_pre: int. The number of neurons in the presynaptic neuron group.
- num_post: int. The number of neurons in the postsynaptic neuron group.
- weight: The synaptic weights.
- sharding: The sharding strategy.
- include_self: bool. Whether connect the neuron with at the same position.
- mode: Mode. The computing mode.
- name: str. The object name.
- """
-
- def __init__(
- self,
- num_pre: int,
- num_post: int,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- include_self: bool = True,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(mode=mode, name=name)
-
- self.num_pre = num_pre
- self.num_post = num_post
- self.include_self = include_self
- self.sharding = sharding
-
- weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding)
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def update(self, pre_val):
- if bm.ndim(self.weight) == 0: # weight is a scalar
- if isinstance(self.mode, bm.BatchingMode):
- assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.'
- post_val = bm.sum(pre_val, keepdims=True, axis=1)
- else:
- assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.'
- post_val = bm.sum(pre_val)
- if not self.include_self:
- if self.num_pre == self.num_post:
- post_val = post_val - pre_val
- elif self.num_pre > self.num_post:
- val = pre_val[:self.num_post]
- post_val = post_val - val
- else:
- val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)])
- post_val = post_val - val
- post_val = self.weight * post_val
-
- else: # weight is a matrix
- assert self.weight.ndim == 2, '"weight" must be a 2D matrix.'
- if not self.include_self:
- post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False)
- else:
- post_val = pre_val @ self.weight
- return post_val
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if isinstance(self.weight, float):
- raise ValueError(f'Cannot update the weight of a constant node.')
- if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
- if on_pre is not None:
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
- if on_post is not None:
- spike = on_post['spike']
- trace = on_post['trace']
- self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
-
-
-class OneToOne(Layer, SupportSTDP):
- """Synaptic matrix multiplication with One2One connection.
-
- Args:
- num: int. The number of neurons.
- weight: The synaptic weight.
- sharding: The sharding strategy.
- mode: The computing mode.
- name: The object name.
-
- """
-
- def __init__(
- self,
- num: int,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(mode=mode, name=name)
-
- self.num = num
- self.sharding = sharding
-
- weight = init.parameter(weight, (self.num,), sharding=sharding)
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def update(self, pre_val):
- return pre_val * self.weight
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if isinstance(self.weight, float):
- raise ValueError(f'Cannot update the weight of a constant node.')
- if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
- if on_pre is not None:
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.weight.value += spike * trace
- if on_post is not None:
- spike = on_post['spike']
- trace = on_post['trace']
- self.weight.value += spike * trace
-
-
-class MaskedLinear(Layer, SupportSTDP):
- r"""Synaptic matrix multiplication with masked dense computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a dense matrix.
-
- >>> import brainpy as bp
- >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100),
- >>> weight=0.1)
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- mask_fun: Masking function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- mask_fun: Callable = Identity(),
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- assert isinstance(conn, connect.TwoEndConnector)
- self.conn = conn
- self.sharding = sharding
- self.mask_fun = mask_fun
-
- # weight
- weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding)
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- # connection
- self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding)
-
- def update(self, x):
- return x @ self.mask_fun(self.weight * self.mask)
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if isinstance(self.weight, float):
- raise ValueError(f'Cannot update the weight of a constant node.')
- if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
- if on_pre is not None:
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
- if on_post is not None:
- spike = on_post['spike']
- trace = on_post['trace']
- self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
-
-
-class _CSRLayer(Layer, SupportSTDP):
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = True,
- ):
- super().__init__(name=name, mode=mode)
-
- assert isinstance(conn, connect.TwoEndConnector)
- assert sharding is None, 'Currently this model does not support sharding.'
- self.conn = conn
- self.sharding = sharding
- self.transpose = transpose
-
- # connection
- self.indices, self.indptr = self.conn.require('csr')
-
- # weight
- weight = init.parameter(weight, (self.indices.size,))
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if bm.isscalar(self.weight):
- raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.')
- if self.weight.shape != self.indices.shape:
- raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.')
- if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
- if on_pre is not None: # update on presynaptic spike
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max)
- if on_post is not None: # update on postsynaptic spike
- if not hasattr(self, '_pre_ids'):
- with jax.ensure_compile_time_eval():
- self._pre_ids, self._post_indptr, self.w_indices = csr2csc(
- [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size)
- )
- spike = on_post['spike']
- trace = on_post['trace']
- self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr,
- self.w_indices, spike, trace, w_min, w_max)
-
-
-class CSRLinear(_CSRLayer):
- r"""Synaptic matrix multiplication with CSR sparse computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a CSR sparse matrix.
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- method: str = 'cusparse',
- transpose: bool = True,
- ):
- super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
- self.method = method
-
- def update(self, x):
- if x.ndim == 1:
- return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
- shape=(self.conn.pre_num, self.conn.post_num),
- transpose=self.transpose,
- method=self.method)
- elif x.ndim > 1:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_csrmv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_csrmv(self, x):
- return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
- shape=(self.conn.pre_num, self.conn.post_num),
- transpose=self.transpose,
- method=self.method)
-
-
-class EventCSRLinear(_CSRLayer):
- r"""Synaptic matrix multiplication with event CSR sparse computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
- :math:`M` the synaptic weight using a CSR sparse matrix.
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = True,
- ):
- super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
-
- def update(self, x):
- if x.ndim == 1:
- return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
- shape=(self.conn.pre_num, self.conn.post_num),
- transpose=self.transpose)
- elif x.ndim > 1:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_csrmv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_csrmv(self, x):
- return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
- shape=(self.conn.pre_num, self.conn.post_num),
- transpose=self.transpose)
-
-
-@numba.njit(nogil=True, fastmath=True, parallel=False)
-def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w):
- out_w[:] = w
- w_min = w_min[()]
- w_max = w_max[()]
- for i in numba.prange(spike.shape[0]): # pre id
- if spike[i]:
- for k in range(indptr[i], indptr[i + 1]): # synapse id
- j = indices[k] # post id
- # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max)
- out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max)
-
-
-csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update)
-
-
-def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
-
-
-@numba.njit(nogil=True, fastmath=True, parallel=False)
-def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
- out_w[:] = w
- w_min = w_min[()]
- w_max = w_max[()]
- for i in numba.prange(spike.shape[0]): # post id
- if spike[i]:
- for k in range(indptr[i], indptr[i + 1]):
- j = post_ids[k] # pre id
- l = w_ids[k] # syn id
- out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max)
-
-
-csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update)
-
-
-def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None):
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
-
-
-class CSCLinear(Layer):
- r"""Synaptic matrix multiplication with CSC sparse computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a CSC sparse matrix.
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- assert isinstance(conn, connect.TwoEndConnector)
- self.conn = conn
- self.sharding = sharding
-
-
-class BcsrMM(Layer):
- r"""Synaptic matrix multiplication with BCSR sparse computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a BCSR sparse matrix.
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- assert isinstance(conn, connect.TwoEndConnector)
- self.conn = conn
- self.sharding = sharding
-
-
-class BcscMM(Layer):
- r"""Synaptic matrix multiplication with BCSC sparse computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a BCSC sparse matrix.
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- assert isinstance(conn, connect.TwoEndConnector)
- self.conn = conn
- self.sharding = sharding
-
-
-class JitFPHomoLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is the same :math:`weight`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- weight: float. The synaptic value at each position.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- weight: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = False,
- atomic: bool = False,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class JitFPUniformLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- w_low: float. The lowest value of the uniform distribution.
- w_high: float. The highest value of the uniform distribution.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- w_low: float,
- w_high: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = False,
- atomic: bool = False,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- self.w_low = w_low
- self.w_high = w_high
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class JitFPNormalLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- w_mu: float. The center of the normal distribution.
- w_sigma: float. The standard variance of the normal distribution.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- w_mu: float,
- w_sigma: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- transpose: bool = False,
- atomic: bool = False,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- self.w_mu = w_mu
- self.w_sigma = w_sigma
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class EventJitFPHomoLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is the same :math:`weight`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- weight: float. The synaptic value at each position.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- weight: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = False,
- atomic: bool = False,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 1000000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class EventJitFPUniformLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- w_low: float. The lowest value of the uniform distribution.
- w_high: float. The highest value of the uniform distribution.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- w_low: float,
- w_high: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = False,
- atomic: bool = False,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- self.w_low = w_low
- self.w_high = w_high
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class EventJitFPNormalLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- w_mu: float. The center of the normal distribution.
- w_sigma: float. The standard variance of the normal distribution.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- w_mu: float,
- w_sigma: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- transpose: bool = False,
- atomic: bool = False,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- self.w_mu = w_mu
- self.w_sigma = w_sigma
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
+# -*- coding: utf-8 -*-
+
+
+import numbers
+from typing import Dict, Optional, Union, Callable
+
+import jax
+import jax.numpy as jnp
+import numba
+import numpy as np
+
+from brainpy import math as bm
+from brainpy._src import connect, initialize as init
+from brainpy._src.context import share
+from brainpy._src.dnn.base import Layer
+from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
+from brainpy._src.dependency_check import import_taichi
+from brainpy.check import is_initializer
+from brainpy.connect import csr2csc
+from brainpy.errors import MathError
+from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
+from brainpy.types import ArrayType, Sharding
+
+ti = import_taichi()
+
+__all__ = [
+ 'Dense', 'Linear',
+ 'Identity',
+ 'AllToAll',
+ 'OneToOne',
+ 'MaskedLinear',
+ 'CSRLinear', 'EventCSRLinear',
+ 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear',
+ 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear',
+]
+
+
+class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline):
+ r"""A linear transformation applied over the last dimension of the input.
+
+ Mathematically, this node can be defined as:
+
+ .. math::
+
+ y = x \cdot weight + b
+
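+  A minimal usage sketch (the sizes below are illustrative only):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.Dense(10, 5)
+  >>> y = layer(bm.ones((8, 10)))  # y.shape == (8, 5)
+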
+ Parameters
+ ----------
+ num_in: int
+ The number of the input feature. A positive integer.
+ num_out: int
+ The number of the output features. A positive integer.
+ W_initializer: optional, Initializer
+ The weight initialization.
+ b_initializer: optional, Initializer
+ The bias initialization.
+  mode: Mode
+    The computing mode. Under a training mode, the weight and bias are trainable. (default None)
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
+ b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super(Dense, self).__init__(mode=mode, name=name)
+
+ # shape
+ self.num_in = num_in
+ self.num_out = num_out
+ if num_in < 0:
+      raise ValueError(f'Received an invalid value for `num_in`, expected '
+ f'a positive integer. Received: num_in={num_in}')
+ if num_out < 0:
+ raise ValueError(f'Received an invalid value for `num_out`, expected '
+ f'a positive integer. Received: num_out={num_out}')
+
+ # weight initializer
+ self.W_initializer = W_initializer
+ self.bias_initializer = b_initializer
+ is_initializer(W_initializer, 'weight_initializer')
+ is_initializer(b_initializer, 'bias_initializer', allow_none=True)
+
+ # parameter initialization
+ W = parameter(self.W_initializer, (num_in, self.num_out))
+ b = parameter(self.bias_initializer, (self.num_out,))
+ if isinstance(self.mode, bm.TrainingMode):
+ W = bm.TrainVar(W)
+ b = None if (b is None) else bm.TrainVar(b)
+ self.W = W
+ self.b = b
+
+ # fitting parameters
+ self.online_fit_by = None # support online training
+ self.offline_fit_by = None # support offline training
+ self.fit_record = dict()
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(name={self.name}, '
+ f'num_in={self.num_in}, '
+ f'num_out={self.num_out}, '
+ f'mode={self.mode})')
+
+ def update(self, x):
+ x = bm.as_jax(x)
+ res = x @ self.W
+ if self.b is not None:
+ res += self.b
+
+ # online fitting data
+ if share.load('fit', False) and self.online_fit_by is not None:
+ self.fit_record['input'] = x
+ self.fit_record['output'] = res
+
+ # offline fitting data
+ if share.load('fit', False) and self.offline_fit_by is not None:
+ self.fit_record['input'] = x
+ self.fit_record['output'] = res
+ return res
+
+ def online_init(self):
+ if self.b is None:
+ num_input = self.num_in
+ else:
+ num_input = self.num_in + 1
+ self.online_fit_by.register_target(feature_in=num_input, identifier=self.name)
+
+ def online_fit(self,
+ target: ArrayType,
+ fit_record: Dict[str, ArrayType]):
+ if not isinstance(target, (bm.ndarray, jnp.ndarray)):
+ raise MathError(f'"target" must be a tensor, but got {type(target)}')
+ x = fit_record['input']
+ y = fit_record['output']
+ if x.ndim != 2:
+ raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, '
+ f'num_feature), but we got {x.shape}')
+ if target.ndim != 2:
+ raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, '
+ f'num_feature), but we got {target.shape}')
+ if x.shape[0] != target.shape[0]:
+ raise ValueError(f'Batch size of the input and target data should be '
+ f'the same, while we got {x.shape[0]} != {target.shape[0]}.')
+ if target.shape[1] != y.shape[1]:
+ raise MathError(f'The output dimension of output and target data should be '
+ f'the same, while we got {target.shape[1]} != {y.shape[1]}')
+
+ # data
+ if self.b is not None:
+ x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1)
+
+ # fitting
+ dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name)
+
+ # assign trained weights
+ if self.b is None:
+ self.W += dW
+ else:
+ db, dW = jnp.split(dW, [1])
+ self.b += db[0]
+ self.W += dW
+
+ def offline_fit(self,
+ target: ArrayType,
+ fit_record: Dict[str, ArrayType]):
+ """The offline training interface for the Dense node."""
+ # data checking
+ if not isinstance(target, (bm.ndarray, jnp.ndarray)):
+ raise MathError(f'"targets" must be a tensor, but got {type(target)}')
+ xs = fit_record['input']
+ ys = fit_record['output']
+ if xs.ndim != 3:
+ raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, '
+ f'num_feature), but we got {xs.shape}')
+ if target.ndim != 3:
+ raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, '
+ f'num_feature), but we got {target.shape}')
+ if ys.shape != target.shape:
+ raise ValueError(f'The shapes of output and target data should be '
+ f'the same, while we got {ys.shape} != {target.shape}.')
+ if xs.shape[0] != target.shape[0]:
+ raise ValueError(f'Batch size of the input and target data should be '
+ f'the same, while we got {xs.shape[0]} != {target.shape[0]}.')
+ if xs.shape[1] != target.shape[1]:
+ raise MathError(f'The time dimension of input and target data should be '
+ f'the same, while we got {xs.shape[1]} != {target.shape[1]}')
+
+ # get input and target training data
+ if self.b is not None:
+ xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input)
+
+ # solve weights by offline training methods
+ weights = self.offline_fit_by(target, xs, ys)
+
+ # assign trained weights
+ if self.b is None:
+ self.W.value = weights
+ else:
+ bias, Wff = jnp.split(weights, [1])
+ self.W.value = Wff
+ self.b.value = bias[0]
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if isinstance(self.W, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(self.W, bm.Variable):
+ self.tracing_variable('W', self.W, self.W.shape)
+ if on_pre is not None:
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max)
+ if on_post is not None:
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max)
+
+
+Linear = Dense
+
+
+class Identity(Layer):
+ r"""A placeholder identity operator that is argument-insensitive.
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ super(Identity, self).__init__(*args, **kwargs)
+
+ def update(self, x):
+ return x
+
+
+# @numba.njit(nogil=True, fastmath=True, parallel=False)
+# def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w):
+# out_w[:] = weight
+# for i in numba.prange(spike.shape[0]):
+# if spike[i]:
+# out_w[i] = np.clip(out_w[i] + trace, w_min, w_max)
+
+@ti.kernel
+def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2),
+ spike: ti.types.ndarray(ndim=1),
+ trace: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ out_w: ti.types.ndarray(ndim=2)):
+ trace0 = trace[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]):
+ out_w[i, j] = weight[i, j]
+ for i in range(spike.shape[0]):
+ if spike[i]:
+ for j in range(out_w.shape[1]):
+ new_value = out_w[i, j] + trace0
+ if new_value < w_min0:
+ out_w[i, j] = w_min0
+ elif new_value > w_max0:
+ out_w[i, j] = w_max0
+ else:
+ out_w[i, j] = new_value
+
+
+@ti.kernel
+def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=2),
+                      spike: ti.types.ndarray(ndim=1),
+                      trace: ti.types.ndarray(ndim=1),
+                      w_min: ti.types.ndarray(ndim=1),
+                      w_max: ti.types.ndarray(ndim=1),
+                      out_w: ti.types.ndarray(ndim=2)):
+ trace0 = trace[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]):
+ out_w[i, j] = weight[i, j]
+ for i in range(spike.shape[0]):
+ if spike[i]:
+ for j in range(out_w.shape[1]):
+ new_value = out_w[i, j] + trace0
+ if new_value < w_min0:
+ out_w[i, j] = w_min0
+ elif new_value > w_max0:
+ out_w[i, j] = w_max0
+ else:
+ out_w[i, j] = new_value
+
+
+dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre,
+ gpu_kernel=_gpu_dense_on_pre)
+
+
+def dense_on_pre(weight, spike, trace, w_min, w_max):
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ trace = jnp.atleast_1d(trace)
+ w_min = jnp.atleast_1d(w_min)
+ w_max = jnp.atleast_1d(w_max)
+ return dense_on_pre_prim(weight, spike, trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
+
+
+# @numba.njit(nogil=True, fastmath=True, parallel=False)
+# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
+# out_w[:] = weight
+# for i in numba.prange(spike.shape[0]):
+# if spike[i]:
+# out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max)
+
+@ti.kernel
+def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2),
+ spike: ti.types.ndarray(ndim=1),
+ trace: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ out_w: ti.types.ndarray(ndim=2)):
+ trace0 = trace[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]):
+ out_w[i, j] = weight[i, j]
+ for i in range(spike.shape[0]):
+ if spike[i]:
+ for j in range(out_w.shape[0]):
+ new_value = out_w[j, i] + trace0
+ if new_value < w_min0:
+ out_w[j, i] = w_min0
+ elif new_value > w_max0:
+ out_w[j, i] = w_max0
+ else:
+ out_w[j, i] = new_value
+
+@ti.kernel
+def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2),
+ spike: ti.types.ndarray(ndim=1),
+ trace: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ out_w: ti.types.ndarray(ndim=2)):
+ trace0 = trace[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]):
+ out_w[i, j] = weight[i, j]
+ for i in range(spike.shape[0]):
+ if spike[i]:
+ for j in range(out_w.shape[0]):
+ new_value = out_w[j, i] + trace0
+ if new_value < w_min0:
+ out_w[j, i] = w_min0
+ elif new_value > w_max0:
+ out_w[j, i] = w_max0
+ else:
+ out_w[j, i] = new_value
+
+dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post,
+ gpu_kernel=_gpu_dense_on_post)
+
+
+def dense_on_post(weight, spike, trace, w_min, w_max):
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ trace = jnp.atleast_1d(trace)
+ w_min = jnp.atleast_1d(w_min)
+ w_max = jnp.atleast_1d(w_max)
+ return dense_on_post_prim(weight, spike, trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
+
+
+class AllToAll(Layer, SupportSTDP):
+ """Synaptic matrix multiplication with All2All connections.
+
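+  A minimal usage sketch (the values below are illustrative); with a scalar
+  weight the layer scales the summed presynaptic input:
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.AllToAll(10, 10, weight=0.1, include_self=False)
+  >>> out = layer(bm.ones(10))  # each entry is 0.9: the sum of the other nine inputs times 0.1
+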
+ Args:
+ num_pre: int. The number of neurons in the presynaptic neuron group.
+ num_post: int. The number of neurons in the postsynaptic neuron group.
+ weight: The synaptic weights.
+ sharding: The sharding strategy.
+    include_self: bool. Whether to connect a neuron to the neuron at the same position.
+ mode: Mode. The computing mode.
+ name: str. The object name.
+ """
+
+ def __init__(
+ self,
+ num_pre: int,
+ num_post: int,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ include_self: bool = True,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(mode=mode, name=name)
+
+ self.num_pre = num_pre
+ self.num_post = num_post
+ self.include_self = include_self
+ self.sharding = sharding
+
+ weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding)
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def update(self, pre_val):
+ if bm.ndim(self.weight) == 0: # weight is a scalar
+ if isinstance(self.mode, bm.BatchingMode):
+ assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.'
+ post_val = bm.sum(pre_val, keepdims=True, axis=1)
+ else:
+ assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.'
+ post_val = bm.sum(pre_val)
+ if not self.include_self:
+ if self.num_pre == self.num_post:
+ post_val = post_val - pre_val
+ elif self.num_pre > self.num_post:
+ val = pre_val[:self.num_post]
+ post_val = post_val - val
+ else:
+ val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)])
+ post_val = post_val - val
+ post_val = self.weight * post_val
+
+ else: # weight is a matrix
+ assert self.weight.ndim == 2, '"weight" must be a 2D matrix.'
+ if not self.include_self:
+ post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False)
+ else:
+ post_val = pre_val @ self.weight
+ return post_val
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ if on_pre is not None:
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
+ if on_post is not None:
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
+
+
+class OneToOne(Layer, SupportSTDP):
+ """Synaptic matrix multiplication with One2One connection.
+
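+  A minimal usage sketch (the values below are illustrative only):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.OneToOne(10, weight=2.0)
+  >>> out = layer(bm.ones(10))  # element-wise product, every entry equals 2.0
+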
+ Args:
+ num: int. The number of neurons.
+ weight: The synaptic weight.
+ sharding: The sharding strategy.
+ mode: The computing mode.
+ name: The object name.
+
+ """
+
+ def __init__(
+ self,
+ num: int,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(mode=mode, name=name)
+
+ self.num = num
+ self.sharding = sharding
+
+ weight = init.parameter(weight, (self.num,), sharding=sharding)
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def update(self, pre_val):
+ return pre_val * self.weight
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ if on_pre is not None:
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.weight.value += spike * trace
+ if on_post is not None:
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.weight.value += spike * trace
+
+
+class MaskedLinear(Layer, SupportSTDP):
+ r"""Synaptic matrix multiplication with masked dense computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a dense matrix.
+
+ >>> import brainpy as bp
+ >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100),
+  ...                        weight=0.1)
+
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ mask_fun: Masking function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ mask_fun: Callable = Identity(),
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
+ self.sharding = sharding
+ self.mask_fun = mask_fun
+
+ # weight
+ weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding)
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ # connection
+ self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding)
+
+ def update(self, x):
+ return x @ self.mask_fun(self.weight * self.mask)
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ if on_pre is not None:
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
+ if on_post is not None:
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
+
+
+class _CSRLayer(Layer, SupportSTDP):
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = True,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ assert isinstance(conn, connect.TwoEndConnector)
+ assert sharding is None, 'Currently this model does not support sharding.'
+ self.conn = conn
+ self.sharding = sharding
+ self.transpose = transpose
+
+ # connection
+ self.indices, self.indptr = self.conn.require('csr')
+
+ # weight
+ weight = init.parameter(weight, (self.indices.size,))
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if bm.isscalar(self.weight):
+ raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.')
+ if self.weight.shape != self.indices.shape:
+      raise ValueError(f'The shape of the weight {self.weight.shape} should be the same as '
+                       f'the shape of the sparse connection indices {self.indices.shape}.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ if on_pre is not None: # update on presynaptic spike
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max)
+ if on_post is not None: # update on postsynaptic spike
+ if not hasattr(self, '_pre_ids'):
+ with jax.ensure_compile_time_eval():
+ self._pre_ids, self._post_indptr, self.w_indices = csr2csc(
+ [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size)
+ )
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr,
+ self.w_indices, spike, trace, w_min, w_max)
+
+
+class CSRLinear(_CSRLayer):
+ r"""Synaptic matrix multiplication with CSR sparse computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a CSR sparse matrix.
+
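+  A minimal usage sketch (the connection and sizes below are illustrative only):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.CSRLinear(bp.conn.FixedProb(0.1, pre=100, post=200), weight=0.01)
+  >>> y = layer(bm.ones(100))  # y.shape == (200,)
+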
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ method: str = None,
+ transpose: bool = True,
+ ):
+ super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
+ self.method = method
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
+ shape=(self.conn.pre_num, self.conn.post_num),
+ method=self.method, transpose=self.transpose)
+ elif x.ndim > 1:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_csrmv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'Unsupported input shape: {x.shape}.')
+
+ def _batch_csrmv(self, x):
+ return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
+ shape=(self.conn.pre_num, self.conn.post_num),
+ method=self.method, transpose=self.transpose)
+
+class EventCSRLinear(_CSRLayer):
+ r"""Synaptic matrix multiplication with event CSR sparse computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+ :math:`M` the synaptic weight using a CSR sparse matrix.
+
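+  A minimal usage sketch (the connection and sizes below are illustrative; the
+  input is a spike vector):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.1, pre=100, post=200), weight=0.01)
+  >>> spikes = bm.random.rand(100) < 0.2
+  >>> y = layer(spikes)  # y.shape == (200,)
+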
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = True,
+ ):
+ super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
+ shape=(self.conn.pre_num, self.conn.post_num),
+ transpose=self.transpose)
+ elif x.ndim > 1:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_csrmv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'Unsupported input shape: {x.shape}.')
+
+ def _batch_csrmv(self, x):
+ return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
+ shape=(self.conn.pre_num, self.conn.post_num),
+ transpose=self.transpose)
+
+# @numba.njit(nogil=True, fastmath=True, parallel=False)
+# def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w):
+# out_w[:] = w
+# w_min = w_min[()]
+# w_max = w_max[()]
+# for i in numba.prange(spike.shape[0]): # pre id
+# if spike[i]:
+# for k in range(indptr[i], indptr[i + 1]): # synapse id
+# j = indices[k] # post id
+# # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max)
+# out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max)
+
+
+@ti.kernel
+def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ spike: ti.types.ndarray(ndim=1),
+ trace: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ out_w: ti.types.ndarray(ndim=1)):
+ trace0 = trace[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ for i in range(out_w.shape[0]):
+ out_w[i] = w[i]
+ for i in range(spike.shape[0]):
+ if spike[i]:
+ for k in range(indptr[i], indptr[i + 1]):
+ j = indices[k]
+        out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0)
+
+
+@ti.kernel
+def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ spike: ti.types.ndarray(ndim=1),
+ trace: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ out_w: ti.types.ndarray(ndim=1)):
+ trace0 = trace[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ for i in range(out_w.shape[0]):
+ out_w[i] = w[i]
+ for i in range(spike.shape[0]):
+ if spike[i]:
+ for k in range(indptr[i], indptr[i + 1]):
+ j = indices[k]
+ out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0)
+
+
+csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_cpu_csr_on_pre_update,
+ gpu_kernel=_gpu_csr_on_pre_update)
+
+
+def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ trace = jnp.atleast_1d(trace)
+ w_min = jnp.atleast_1d(w_min)
+ w_max = jnp.atleast_1d(w_max)
+ return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+
+@numba.njit(nogil=True, fastmath=True, parallel=False)
+def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
+ out_w[:] = w
+ w_min = w_min[()]
+ w_max = w_max[()]
+ for i in numba.prange(spike.shape[0]): # post id
+ if spike[i]:
+ for k in range(indptr[i], indptr[i + 1]):
+ j = post_ids[k] # pre id
+ l = w_ids[k] # syn id
+ out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max)
+
+
+csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update)
+
+
+def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None):
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+
+
+
+class CSCLinear(Layer):
+ r"""Synaptic matrix multiplication with CSC sparse computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a CSC sparse matrix.
+
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
+ self.sharding = sharding
+
+
+class BcsrMM(Layer):
+ r"""Synaptic matrix multiplication with BCSR sparse computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a BCSR sparse matrix.
+
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
+ self.sharding = sharding
+
+
+class BcscMM(Layer):
+ r"""Synaptic matrix multiplication with BCSC sparse computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a BCSC sparse matrix.
+
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
+ self.sharding = sharding
+
+
+class JitFPHomoLinear(Layer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
+  :math:`M` the synaptic weight matrix with fixed sparse connectivity and weights.
+ Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+ and at each connection, the synaptic value is the same :math:`weight`.
+
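+  A minimal usage sketch (the sizes and parameters below are illustrative only):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.JitFPHomoLinear(100, 200, prob=0.1, weight=0.01, seed=42)
+  >>> y = layer(bm.ones(100))  # y.shape == (200,)
+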
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output features. A positive integer.
+ prob: float. The connectivity probability.
+ weight: float. The synaptic value at each position.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+ atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ weight: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = False,
+ atomic: bool = False,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'Unsupported input shape: {x.shape}.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class JitFPUniformLinear(Layer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
+  :math:`M` the synaptic weight matrix with fixed sparse connectivity and weights.
+  Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+  and at each connection, the synaptic value is sampled from a uniform distribution :math:`U(w_{low}, w_{high})`.
+
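+  A minimal usage sketch (the sizes and parameters below are illustrative only):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.JitFPUniformLinear(100, 200, prob=0.1, w_low=0., w_high=1., seed=42)
+  >>> y = layer(bm.ones(100))  # y.shape == (200,)
+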
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output features. A positive integer.
+ prob: float. The connectivity probability.
+ w_low: float. The lowest value of the uniform distribution.
+ w_high: float. The highest value of the uniform distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+ atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ w_low: float,
+ w_high: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = False,
+ atomic: bool = False,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ self.w_low = w_low
+ self.w_high = w_high
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'Unsupported input shape: {x.shape}.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class JitFPNormalLinear(Layer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
+  :math:`M` the synaptic weight matrix with fixed sparse connectivity and weights.
+  Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+  and at each connection, the synaptic value is sampled from a normal distribution :math:`N(\mu, \sigma)`.
+
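+  A minimal usage sketch (the sizes and parameters below are illustrative only):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.JitFPNormalLinear(100, 200, prob=0.1, w_mu=0., w_sigma=0.1, seed=42)
+  >>> y = layer(bm.ones(100))  # y.shape == (200,)
+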
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output features. A positive integer.
+ prob: float. The connectivity probability.
+ w_mu: float. The center of the normal distribution.
+    w_sigma: float. The standard deviation of the normal distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+ atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ w_mu: float,
+ w_sigma: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ transpose: bool = False,
+ atomic: bool = False,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ self.w_mu = w_mu
+ self.w_sigma = w_sigma
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'Unsupported input shape: {x.shape}.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class EventJitFPHomoLinear(Layer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+  :math:`M` the synaptic weight matrix with fixed sparse connectivity and weights.
+ Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+ and at each connection, the synaptic value is the same :math:`weight`.
+
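+  A minimal usage sketch (the sizes and parameters below are illustrative; the
+  input is a spike vector):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.EventJitFPHomoLinear(100, 200, prob=0.1, weight=0.01, seed=42)
+  >>> spikes = bm.random.rand(100) < 0.2
+  >>> y = layer(spikes)  # y.shape == (200,)
+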
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output features. A positive integer.
+ prob: float. The connectivity probability.
+ weight: float. The synaptic value at each position.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+    atomic: bool. Compute the post-synaptic value with the atomic summation. Default True.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ weight: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = False,
+ atomic: bool = True,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 1000000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'Unsupported input shape: {x.shape}.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class EventJitFPUniformLinear(Layer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+  :math:`M` the synaptic weight matrix with fixed sparse connectivity and weights.
+  Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+  and at each connection, the synaptic value is sampled from a uniform distribution :math:`U(w_{low}, w_{high})`.
+
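+  A minimal usage sketch (the sizes and parameters below are illustrative; the
+  input is a spike vector):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.EventJitFPUniformLinear(100, 200, prob=0.1, w_low=0., w_high=1., seed=42)
+  >>> spikes = bm.random.rand(100) < 0.2
+  >>> y = layer(spikes)  # y.shape == (200,)
+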
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output features. A positive integer.
+ prob: float. The connectivity probability.
+ w_low: float. The lowest value of the uniform distribution.
+ w_high: float. The highest value of the uniform distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+    atomic: bool. Compute the post-synaptic value with the atomic summation. Default True.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ w_low: float,
+ w_high: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = False,
+ atomic: bool = True,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ self.w_low = w_low
+ self.w_high = w_high
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'Unsupported input shape: {x.shape}.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class EventJitFPNormalLinear(Layer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+  :math:`M` the synaptic weight matrix with fixed sparse connectivity and weights.
+  Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+  and at each connection, the synaptic value is sampled from a normal distribution :math:`N(\mu, \sigma)`.
+
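+  A minimal usage sketch (the sizes and parameters below are illustrative; the
+  input is a spike vector):
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>> layer = bp.dnn.EventJitFPNormalLinear(100, 200, prob=0.1, w_mu=0., w_sigma=0.1, seed=42)
+  >>> spikes = bm.random.rand(100) < 0.2
+  >>> y = layer(spikes)  # y.shape == (200,)
+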
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output features. A positive integer.
+ prob: float. The connectivity probability.
+ w_mu: float. The center of the normal distribution.
+    w_sigma: float. The standard deviation of the normal distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+    atomic: bool. Compute the post-synaptic value with the atomic summation. Default True.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ w_mu: float,
+ w_sigma: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ transpose: bool = False,
+ atomic: bool = True,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ self.w_mu = w_mu
+ self.w_sigma = w_sigma
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'Unsupported input shape: {x.shape}.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py
index ba2a49efd..17054667d 100644
--- a/brainpy/_src/dnn/tests/test_activation.py
+++ b/brainpy/_src/dnn/tests/test_activation.py
@@ -1,5 +1,6 @@
-from absl.testing import parameterized
from absl.testing import absltest
+from absl.testing import parameterized
+
import brainpy as bp
import brainpy.math as bm
diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py
index 3c9fdfa87..05f523622 100644
--- a/brainpy/_src/dnn/tests/test_conv_layers.py
+++ b/brainpy/_src/dnn/tests/test_conv_layers.py
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
-from unittest import TestCase
-from absl.testing import absltest
import jax.numpy as jnp
-import brainpy.math as bm
+from absl.testing import absltest
from absl.testing import parameterized
+
import brainpy as bp
import brainpy.math as bm
class TestConv(parameterized.TestCase):
def test_Conv2D_img(self):
- bm.random.seed()
img = jnp.zeros((2, 200, 198, 4))
for k in range(4):
x = 30 + 60 * k
@@ -24,6 +22,7 @@ def test_Conv2D_img(self):
strides=(2, 1), padding='VALID', groups=4)
out = net(img)
print("out shape: ", out.shape)
+ self.assertEqual(out.shape, (2, 99, 196, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(img)[0, :, :, 0])
@@ -31,7 +30,6 @@ def test_Conv2D_img(self):
bm.clear_buffer_memory()
def test_conv1D(self):
- bm.random.seed()
with bp.math.training_environment():
model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))
@@ -39,6 +37,7 @@ def test_conv1D(self):
out = model(input)
print("out shape: ", out.shape)
+ self.assertEqual(out.shape, (2, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :])
@@ -54,6 +53,7 @@ def test_conv2D(self):
out = model(input)
print("out shape: ", out.shape)
+ self.assertEqual(out.shape, (2, 5, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :, 31])
@@ -67,6 +67,7 @@ def test_conv3D(self):
input = bp.math.ones((2, 5, 5, 5, 3))
out = model(input)
print("out shape: ", out.shape)
+ self.assertEqual(out.shape, (2, 5, 5, 5, 32))
bm.clear_buffer_memory()
diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py
index a686d2a41..9ad15938d 100644
--- a/brainpy/_src/dnn/tests/test_function.py
+++ b/brainpy/_src/dnn/tests/test_function.py
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-
-from unittest import TestCase
-
-import jax.numpy as jnp
-import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized
+
import brainpy as bp
+import brainpy.math as bm
class TestFunction(parameterized.TestCase):
@@ -33,6 +31,15 @@ def test_flatten_non_batching_mode(self):
self.assertEqual(output.shape, expected_shape)
bm.clear_buffer_memory()
+ def test_unflatten(self):
+ bm.random.seed()
+ layer = bp.dnn.Unflatten(1, (10, 6), mode=bm.NonBatchingMode())
+ input = bm.random.randn(5, 60)
+ output = layer.update(input)
+ expected_shape = (5, 10, 6)
+ self.assertEqual(output.shape, expected_shape)
+ bm.clear_buffer_memory()
+
if __name__ == '__main__':
absltest.main()
diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py
index da49bdbfe..df5293ab9 100644
--- a/brainpy/_src/dnn/tests/test_linear.py
+++ b/brainpy/_src/dnn/tests/test_linear.py
@@ -1,6 +1,7 @@
-import brainpy as bp
-from absl.testing import parameterized
from absl.testing import absltest
+from absl.testing import parameterized
+
+import brainpy as bp
import brainpy.math as bm
@@ -213,6 +214,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
self.assertTrue(y2.shape == shape + (200,))
bm.clear_buffer_memory()
-
if __name__ == '__main__':
absltest.main()
diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py
index 0d754976f..3cf923d7b 100644
--- a/brainpy/_src/dnn/tests/test_mode.py
+++ b/brainpy/_src/dnn/tests/test_mode.py
@@ -1,7 +1,8 @@
-import brainpy.math as bm
-from absl.testing import parameterized
from absl.testing import absltest
+from absl.testing import parameterized
+
import brainpy as bp
+import brainpy.math as bm
class Test_Conv(parameterized.TestCase):
diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py
index fdc5b34e3..de2c9765b 100644
--- a/brainpy/_src/dnn/tests/test_normalization.py
+++ b/brainpy/_src/dnn/tests/test_normalization.py
@@ -1,7 +1,8 @@
-import brainpy.math as bm
-from absl.testing import parameterized
from absl.testing import absltest
+from absl.testing import parameterized
+
import brainpy as bp
+import brainpy.math as bm
class Test_Normalization(parameterized.TestCase):
diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py
index 34f8f5cd5..5748edd8b 100644
--- a/brainpy/_src/dnn/tests/test_pooling_layers.py
+++ b/brainpy/_src/dnn/tests/test_pooling_layers.py
@@ -3,8 +3,8 @@
import jax
import jax.numpy as jnp
import numpy as np
-from absl.testing import parameterized
from absl.testing import absltest
+from absl.testing import parameterized
import brainpy as bp
import brainpy.math as bm
diff --git a/brainpy/_src/dyn/_docs.py b/brainpy/_src/dyn/_docs.py
index c2c75ffc9..d528d4266 100644
--- a/brainpy/_src/dyn/_docs.py
+++ b/brainpy/_src/dyn/_docs.py
@@ -40,3 +40,166 @@
ltc_doc = 'with liquid time-constant'
+
+dual_exp_syn_doc = r'''
+
+ **Model Descriptions**
+
+ The dual exponential synapse model [1]_, also known as the *difference of two exponentials* model,
+ is given by:
+
+ .. math::
+
+ g_{\mathrm{syn}}(t)=g_{\mathrm{max}} A \left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right)
+ -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right)
+
+ where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2`
+ is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic
+ spike, and :math:`g_{\mathrm{max}}` is the maximal conductance.
+
+ However, in practice, this formula is hard to implement directly. An equivalent formulation
+ uses two coupled linear differential equations [2]_:
+
+ .. math::
+
+ \begin{aligned}
+ &\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\
+ &\frac{d h}{d t}=-\frac{h}{\tau_{\text{rise}}}+ \left(\frac{1}{\tau_{\text{rise}}} - \frac{1}{\tau_{\text{decay}}}\right) A \delta\left(t_{0}-t\right),
+ \end{aligned}
+
+ By default, :math:`A` has the following value:
+
+ .. math::
+
+ A = \frac{{\tau }_{decay}}{{\tau }_{decay}-{\tau }_{rise}}{\left(\frac{{\tau }_{rise}}{{\tau }_{decay}}\right)}^{\frac{{\tau }_{rise}}{{\tau }_{rise}-{\tau }_{decay}}}
+
+ .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
+ "The Synapse." Principles of Computational Modelling in Neuroscience.
+ Cambridge: Cambridge UP, 2011. 172-95. Print.
+ .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational
+ Modeling Methods for Neuroscientists.
+
+'''
+
+dual_exp_args = '''
+ tau_decay: float, ArrayType, Callable. The time constant of the synaptic decay phase. [ms]
+ tau_rise: float, ArrayType, Callable. The time constant of the synaptic rise phase. [ms]
+ A: float. The normalization factor. Default None.
+
+'''
+
+
+alpha_syn_doc = r'''
+
+ **Model Descriptions**
+
+ The analytical expression of the alpha synapse is given by:
+
+ .. math::
+
+ g_{syn}(t)= g_{max} \frac{t-t_{s}}{\tau} \exp \left(-\frac{t-t_{s}}{\tau}\right).
+
+ However, this equation is hard to implement directly, so we convert it into the
+ equivalent differential form:
+
+ .. math::
+
+ \begin{aligned}
+ &\frac{d g}{d t}=-\frac{g}{\tau}+\frac{h}{\tau} \\
+ &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
+ \end{aligned}
+
+ .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
+ "The Synapse." Principles of Computational Modelling in Neuroscience.
+ Cambridge: Cambridge UP, 2011. 172-95. Print.
+
+
+'''
+
+
+exp_syn_doc = r'''
+
+ **Model Descriptions**
+
+ The single exponential decay synapse model assumes that the release of neurotransmitter,
+ its diffusion across the cleft, the receptor binding, and channel opening all happen
+ very quickly, so that the channels instantaneously jump from the closed to the open state.
+ Therefore, its expression is given by
+
+ .. math::
+
+ g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau}
+
+ where :math:`\tau` is the time constant of the synaptic state decay,
+ :math:`t_0` is the time of the pre-synaptic spike,
+ :math:`g_{\mathrm{max}}` is the maximal conductance.
+
+ Accordingly, the differential form of the exponential synapse is given by
+
+ .. math::
+
+ \begin{aligned}
+ & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}).
+ \end{aligned}
+
+ .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
+ "The Synapse." Principles of Computational Modelling in Neuroscience.
+ Cambridge: Cambridge UP, 2011. 172-95. Print.
+
+'''
+
+
+std_doc = r'''
+
+ This model filters the synaptic current by the following equation:
+
+ .. math::
+
+ I_{syn}^+(t) = I_{syn}^-(t) * x
+
+ where :math:`x` is the normalized variable between 0 and 1, and
+ :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
+ and after STD filtering.
+
+ Moreover, :math:`x` is updated according to the dynamics of:
+
+ .. math::
+
+ \frac{dx}{dt} = \frac{1-x}{\tau} - U * x * \delta(t-t_{spike})
+
+ where :math:`U` is the fraction of resources used per action potential,
+ and :math:`\tau` is the time constant of recovery of the synaptic vesicles.
+
+'''
+
+
+stp_doc = r'''
+
+ This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`.
+
+ .. math::
+
+ I_{syn}^+(t) = I_{syn}^-(t) * x * u
+
+ where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
+ and after STP filtering, :math:`x` denotes the fraction of resources that remain available
+ after neurotransmitter depletion, and :math:`u` represents the fraction of available
+ resources ready for use (release probability).
+
+ The dynamics of :math:`u` and :math:`x` are governed by
+
+ .. math::
+
+ \begin{aligned}
+ \frac{du}{dt} &= -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}), \\
+ \frac{dx}{dt} &= \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}),
+ \end{aligned}
+
+ where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment
+ of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding
+ variables just before the arrival of the spike, and :math:`u^+`
+ refers to the moment just after the spike.
+
+
+'''
+
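As a quick sanity check on the dual-exponential documentation added above, the coupled ODE form can be integrated numerically and compared with the closed-form difference of exponentials. A minimal NumPy sketch (illustrative only, not BrainPy code; parameter values are arbitrary), assuming a single spike at t = 0:

.. code-block:: python

    import numpy as np

    tau_decay, tau_rise, dt = 10.0, 1.0, 0.01        # ms
    # default normalization factor A from the docstring above
    A = tau_decay / (tau_decay - tau_rise) * (tau_rise / tau_decay) ** (
        tau_rise / (tau_rise - tau_decay))
    ts = np.arange(0.0, 50.0, dt)

    # the delta input at t = 0 increments h by (1/tau_rise - 1/tau_decay) * A
    g, h = 0.0, (1.0 / tau_rise - 1.0 / tau_decay) * A
    g_num = np.empty_like(ts)
    for i in range(ts.size):
        g_num[i] = g
        g += dt * (-g / tau_decay + h)               # dg/dt = -g/tau_decay + h
        h += dt * (-h / tau_rise)                    # dh/dt = -h/tau_rise

    g_ref = A * (np.exp(-ts / tau_decay) - np.exp(-ts / tau_rise))
    print(np.abs(g_num - g_ref).max())               # small Euler discretization error

The maximum deviation shrinks with the step size, confirming that the two coupled equations reproduce the closed-form expression.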
diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py
index 7a985cb9d..f9145a94b 100644
--- a/brainpy/_src/dyn/neurons/hh.py
+++ b/brainpy/_src/dyn/neurons/hh.py
@@ -61,7 +61,7 @@ class CondNeuGroupLTC(HHTypedNeuron, Container, TreeNode):
where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants.
.. versionadded:: 2.1.9
- Model the conductance-based neuron model.
+ Modeling the conductance-based neuron model.
Parameters
----------
@@ -117,7 +117,7 @@ def __init__(
def derivative(self, V, t, I):
# synapses
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
# channels
for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values():
I = I + ch.current(V)
@@ -140,7 +140,7 @@ def update(self, x=None):
x = x * (1e-3 / self.A)
# integral
- V = self.integral(self.V.value, share['t'], x, share['dt'])
+ V = self.integral(self.V.value, share['t'], x, share['dt']) + self.sum_delta_inputs()
# check whether the children channels have the correct parents.
channels = self.nodes(level=1, include_self=False).subset(IonChaDyn).unique()
@@ -176,7 +176,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
# inputs
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -348,7 +348,8 @@ def __init__(
self.reset_state(self.mode)
# m channel
- m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
+ # m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
+ m_alpha = lambda self, V: 1. / bm.exprel(-(V + 40) / 10)
m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
@@ -360,7 +361,8 @@ def __init__(
dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h
# n channel
- n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
+ # n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
+ n_alpha = lambda self, V: 0.1 / bm.exprel(-(V + 55) / 10)
n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
@@ -382,9 +384,10 @@ def reset_state(self, batch_size=None, **kwargs):
self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size)
def dV(self, V, t, m, h, n, I):
- I = self.sum_inputs(V, init=I)
- I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
- I_K = (self.gK * n ** 4.0) * (V - self.EK)
+ I = self.sum_current_inputs(V, init=I)
+ I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
+ n2 = n * n
+ I_K = (self.gK * n2 * n2) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + I) / self.C
return dVdt
@@ -399,6 +402,7 @@ def update(self, x=None):
x = 0. if x is None else x
V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt)
+ V += self.sum_delta_inputs()
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.m.value = m
@@ -516,8 +520,9 @@ class HH(HHLTC):
"""
def dV(self, V, t, m, h, n, I):
- I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
- I_K = (self.gK * n ** 4.0) * (V - self.EK)
+ I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
+ n2 = n * n
+ I_K = (self.gK * n2 * n2) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + I) / self.C
return dVdt
@@ -528,7 +533,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -658,7 +663,7 @@ def reset_state(self, batch_or_mode=None, **kwargs):
self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_or_mode)
def dV(self, V, t, W, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
I_Ca = self.g_Ca * M_inf * (V - self.V_Ca)
I_K = self.g_K * W * (V - self.V_K)
@@ -680,9 +685,8 @@ def update(self, x=None):
t = share.load('t')
dt = share.load('dt')
x = 0. if x is None else x
-
V, W = self.integral(self.V, self.W, t, x, dt)
-
+ V += self.sum_delta_inputs()
spike = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.W.value = W
@@ -759,7 +763,7 @@ def dV(self, V, t, W, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -930,7 +934,8 @@ def reset_state(self, batch_size=None):
self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size)
def m_inf(self, V):
- alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
+ # alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
+ alpha = 1. / bm.exprel(-0.1 * (V + 35))
beta = 4. * bm.exp(-(V + 60.) / 18.)
return alpha / (alpha + beta)
@@ -941,13 +946,14 @@ def dh(self, h, t, V):
return self.phi * dhdt
def dn(self, n, t, V):
- alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
+ # alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
+ alpha = 1. / bm.exprel(-0.1 * (V + 34))
beta = 0.125 * bm.exp(-(V + 44) / 80)
dndt = alpha * (1 - n) - beta * n
return self.phi * dndt
def dV(self, V, t, h, n, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa)
IK = self.gK * n ** 4 * (V - self.EK)
IL = self.gL * (V - self.EL)
@@ -964,6 +970,7 @@ def update(self, x=None):
x = 0. if x is None else x
V, h, n = self.integral(self.V, self.h, self.n, t, x, dt)
+ V += self.sum_delta_inputs()
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.h.value = h
@@ -1087,5 +1094,5 @@ def dV(self, V, t, h, n, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
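The hh.py hunks above replace rate expressions of the form ``0.1 * (V + 40) / (1 - exp(-(V + 40) / 10))`` with ``1. / bm.exprel(-(V + 40) / 10)``. The two are algebraically identical, but the ``exprel`` form stays well-defined at the removable singularity (here V = -40 mV). A small sketch of the equivalence using ``scipy.special.exprel``, which follows the same ``(exp(x) - 1) / x`` definition (a standalone illustration, not BrainPy code):

.. code-block:: python

    import numpy as np
    from scipy.special import exprel   # exprel(x) = (exp(x) - 1) / x, exprel(0) = 1

    def m_alpha_naive(V):
        # textbook form: evaluates to 0/0 at V = -40 mV
        return 0.1 * (V + 40.0) / (1.0 - np.exp(-(V + 40.0) / 10.0))

    def m_alpha_stable(V):
        # rewritten form used in the diff above
        return 1.0 / exprel(-(V + 40.0) / 10.0)

    V = np.array([-80.0, -60.0, -39.9, -20.0, 0.0])
    print(np.allclose(m_alpha_naive(V), m_alpha_stable(V)))   # True away from V = -40
    print(m_alpha_stable(np.array([-40.0])))                  # [1.] -- finite at the singularity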
diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py
index 988c915ac..11934d9dc 100644
--- a/brainpy/_src/dyn/neurons/lif.py
+++ b/brainpy/_src/dyn/neurons/lif.py
@@ -5,12 +5,12 @@
import brainpy.math as bm
from brainpy._src.context import share
+from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc
+from brainpy._src.dyn.neurons.base import GradNeuDyn
from brainpy._src.initialize import ZeroInit, OneInit
from brainpy._src.integrators import odeint, JointEq
from brainpy.check import is_initializer
from brainpy.types import Shape, ArrayType, Sharding
-from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc
-from brainpy._src.dyn.neurons.base import GradNeuDyn
__all__ = [
'IF',
@@ -119,7 +119,7 @@ def __init__(
self.reset_state(self.mode)
def derivative(self, V, t, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
return (-V + self.V_rest + self.R * I) / self.tau
def reset_state(self, batch_size=None, **kwargs):
@@ -132,7 +132,7 @@ def update(self, x=None):
x = 0. if x is None else x
# integrate membrane potential
- self.V.value = self.integral(self.V.value, t, x, dt)
+ self.V.value = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs()
return self.V.value
@@ -146,7 +146,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -252,7 +252,7 @@ def __init__(
self.reset_state(self.mode)
def derivative(self, V, t, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
return (-V + self.V_rest + self.R * I) / self.tau
def reset_state(self, batch_size=None, **kwargs):
@@ -265,7 +265,7 @@ def update(self, x=None):
x = 0. if x is None else x
# integrate membrane potential
- V = self.integral(self.V.value, t, x, dt)
+ V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs()
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
@@ -337,7 +337,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -464,7 +464,7 @@ def update(self, x=None):
x = 0. if x is None else x
# integrate membrane potential
- V = self.integral(self.V.value, t, x, dt)
+ V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs()
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
@@ -552,7 +552,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -723,7 +723,7 @@ def __init__(
self.reset_state(self.mode)
def derivative(self, V, t, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau
return dvdt
@@ -738,7 +738,7 @@ def update(self, x=None):
x = 0. if x is None else x
# integrate membrane potential
- V = self.integral(self.V.value, t, x, dt)
+ V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs()
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
@@ -880,7 +880,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -994,6 +994,7 @@ class ExpIFRefLTC(ExpIFLTC):
%s
"""
+
def __init__(
self,
size: Shape,
@@ -1076,7 +1077,7 @@ def update(self, x=None):
x = 0. if x is None else x
# integrate membrane potential
- V = self.integral(self.V.value, t, x, dt)
+ V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs()
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
@@ -1221,6 +1222,7 @@ class ExpIFRef(ExpIFRefLTC):
%s
%s
"""
+
def derivative(self, V, t, I):
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau
@@ -1228,7 +1230,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -1400,7 +1402,7 @@ def __init__(
self.reset_state(self.mode)
def dV(self, V, t, w, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau
return dVdt
@@ -1425,6 +1427,7 @@ def update(self, x=None):
# integrate membrane potential
V, w = self.integral(self.V.value, self.w.value, t, x, dt)
+ V += self.sum_delta_inputs()
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
@@ -1559,7 +1562,7 @@ def dV(self, V, t, w, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -1757,6 +1760,7 @@ def update(self, x=None):
# integrate membrane potential
V, w = self.integral(self.V.value, self.w.value, t, x, dt)
+ V += self.sum_delta_inputs()
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
@@ -1901,7 +1905,7 @@ def dV(self, V, t, w, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -2040,7 +2044,7 @@ def __init__(
self.reset_state(self.mode)
def derivative(self, V, t, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau
return dVdt
@@ -2054,7 +2058,7 @@ def update(self, x=None):
x = 0. if x is None else x
# integrate membrane potential
- V = self.integral(self.V.value, t, x, dt)
+ V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs()
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
@@ -2166,7 +2170,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -2330,7 +2334,7 @@ def update(self, x=None):
x = 0. if x is None else x
# integrate membrane potential
- V = self.integral(self.V.value, t, x, dt)
+ V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs()
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
@@ -2444,14 +2448,13 @@ class QuaIFRef(QuaIFRefLTC):
%s
"""
-
def derivative(self, V, t, I):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau
return dVdt
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -2609,7 +2612,7 @@ def __init__(
self.reset_state(self.mode)
def dV(self, V, t, w, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau
return dVdt
@@ -2633,6 +2636,7 @@ def update(self, x=None):
# integrate membrane potential
V, w = self.integral(self.V.value, self.w.value, t, x, dt)
+ V += self.sum_delta_inputs()
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
@@ -2756,7 +2760,7 @@ def dV(self, V, t, w, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -2939,6 +2943,7 @@ def update(self, x=None):
# integrate membrane potential
V, w = self.integral(self.V.value, self.w.value, t, x, dt)
+ V += self.sum_delta_inputs()
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
@@ -3072,7 +3077,7 @@ def dV(self, V, t, w, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -3279,7 +3284,7 @@ def dVth(self, V_th, t, V):
return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf)
def dV(self, V, t, I1, I2, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau
@property
@@ -3300,6 +3305,7 @@ def update(self, x=None):
# integrate membrane potential
I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt)
+ V += self.sum_delta_inputs()
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
@@ -3452,7 +3458,7 @@ def dV(self, V, t, I1, I2, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -3573,7 +3579,6 @@ class GifRefLTC(GifLTC):
%s
"""
-
def __init__(
self,
size: Shape,
@@ -3680,6 +3685,7 @@ def update(self, x=None):
# integrate membrane potential
I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt)
+ V += self.sum_delta_inputs()
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
@@ -3840,13 +3846,12 @@ class GifRef(GifRefLTC):
%s
"""
-
def dV(self, V, t, I1, I2, I):
return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -4012,7 +4017,7 @@ def __init__(
self.reset_state(self.mode)
def dV(self, V, t, u, I):
- I = self.sum_inputs(V, init=I)
+ I = self.sum_current_inputs(V, init=I)
dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I
return dVdt
@@ -4040,6 +4045,7 @@ def update(self, x=None):
# integrate membrane potential
V, u = self.integral(self.V.value, self.u.value, t, x, dt)
+ V += self.sum_delta_inputs()
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
@@ -4161,7 +4167,7 @@ def dV(self, V, t, u, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
@@ -4351,6 +4357,7 @@ def update(self, x=None):
# integrate membrane potential
V, u = self.integral(self.V.value, self.u.value, t, x, dt)
+ V += self.sum_delta_inputs()
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
@@ -4485,11 +4492,11 @@ def dV(self, V, t, u, I):
def update(self, x=None):
x = 0. if x is None else x
- x = self.sum_inputs(self.V.value, init=x)
+ x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)
-Izhikevich.__doc__ = Izhikevich.__doc__ %(pneu_doc, dpneu_doc)
-IzhikevichRefLTC.__doc__ = IzhikevichRefLTC.__doc__ %(pneu_doc, dpneu_doc, ref_doc)
-IzhikevichRef.__doc__ = IzhikevichRef.__doc__ %(pneu_doc, dpneu_doc, ref_doc)
-IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ %()
+Izhikevich.__doc__ = Izhikevich.__doc__ % (pneu_doc, dpneu_doc)
+IzhikevichRefLTC.__doc__ = IzhikevichRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc)
+IzhikevichRef.__doc__ = IzhikevichRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc)
+IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ % ()
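Throughout lif.py (and hh.py above), the single ``sum_inputs`` call is split into ``sum_current_inputs``, which enters the derivative during integration, and ``sum_delta_inputs``, which is added to the membrane potential after the integration step. A conceptual NumPy sketch of that split (illustrative only, not the BrainPy API):

.. code-block:: python

    import numpy as np

    V_rest, R, tau, dt = -65.0, 1.0, 10.0, 0.1

    def step(V, current_inputs=(), delta_inputs=()):
        I = sum(current_inputs)                       # analogous to sum_current_inputs()
        V = V + dt * (-(V - V_rest) + R * I) / tau    # one Euler step of the LIF equation
        V = V + sum(delta_inputs)                     # analogous to sum_delta_inputs()
        return V

    V = V_rest
    V = step(V, current_inputs=[2.0])                 # smooth current-like drive
    V = step(V, delta_inputs=[1.5])                   # instantaneous jump, e.g. a delta synapse
    print(V)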
diff --git a/brainpy/_src/dyn/others/common.py b/brainpy/_src/dyn/others/common.py
index 7cf4f98b8..812375787 100644
--- a/brainpy/_src/dyn/others/common.py
+++ b/brainpy/_src/dyn/others/common.py
@@ -77,7 +77,7 @@ def update(self, inp=None):
dt = share.load('dt')
self.x.value = self.integral(self.x.value, t, dt)
if inp is None: inp = 0.
- inp = self.sum_inputs(self.x.value, init=inp)
+ inp = self.sum_current_inputs(self.x.value, init=inp)
self.x += inp
return self.x.value
diff --git a/brainpy/_src/dyn/outs/outputs.py b/brainpy/_src/dyn/outs/outputs.py
index 5dc54a232..8171367d7 100644
--- a/brainpy/_src/dyn/outs/outputs.py
+++ b/brainpy/_src/dyn/outs/outputs.py
@@ -82,7 +82,7 @@ def __init__(
super().__init__(name=name, scaling=scaling)
def update(self, conductance, potential=None):
- return self.std_scaling(conductance)
+ return conductance
class MgBlock(SynOut):
@@ -138,5 +138,5 @@ def __init__(
self.beta = init.parameter(beta, np.shape(beta), sharding=sharding)
def update(self, conductance, potential):
- return conductance *\
- (self.E - potential) / (1 + self.cc_Mg / self.beta * bm.exp(self.alpha * (self.V_offset - potential)))
+ norm = (1 + self.cc_Mg / self.beta * bm.exp(self.alpha * (self.V_offset - potential)))
+ return conductance * (self.E - potential) / norm
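The ``MgBlock`` rewrite above returns the conductance times the driving force, scaled by the magnesium-block normalization term. A short NumPy sketch with illustrative NMDA-style parameter values (assumptions for demonstration, not necessarily the class defaults):

.. code-block:: python

    import numpy as np

    E, cc_Mg, alpha, beta, V_offset = 0.0, 1.2, 0.062, 3.57, 0.0   # illustrative values
    V = np.linspace(-80.0, 0.0, 5)                                  # membrane potentials (mV)
    g = 1.0                                                         # synaptic conductance

    norm = 1.0 + cc_Mg / beta * np.exp(alpha * (V_offset - V))
    print(np.round(1.0 / norm, 3))          # Mg2+ unblock factor: small at -80 mV, ~0.75 near 0 mV
    print(np.round(g * (E - V) / norm, 3))  # the post-synaptic current MgBlock returns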
diff --git a/brainpy/_src/dyn/projections/__init__.py b/brainpy/_src/dyn/projections/__init__.py
index 8a7040824..e69de29bb 100644
--- a/brainpy/_src/dyn/projections/__init__.py
+++ b/brainpy/_src/dyn/projections/__init__.py
@@ -1,5 +0,0 @@
-
-from .aligns import *
-from .conn import *
-from .others import *
-from .inputs import *
diff --git a/brainpy/_src/dyn/projections/align_post.py b/brainpy/_src/dyn/projections/align_post.py
new file mode 100644
index 000000000..9bd280f81
--- /dev/null
+++ b/brainpy/_src/dyn/projections/align_post.py
@@ -0,0 +1,507 @@
+from typing import Optional, Callable, Union
+
+from brainpy import math as bm, check
+from brainpy._src.delay import (delay_identifier,
+ register_delay_by_return)
+from brainpy._src.dynsys import DynamicalSystem, Projection
+from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost)
+
+__all__ = [
+ 'HalfProjAlignPostMg', 'FullProjAlignPostMg',
+ 'HalfProjAlignPost', 'FullProjAlignPost',
+
+]
+
+
+def get_post_repr(out_label, syn, out):
+ return f'{out_label} // {syn.identifier} // {out.identifier}'
+
+
+def align_post_add_bef_update(out_label, syn_desc, out_desc, post, proj_name):
+ # synapse and output initialization
+ _post_repr = get_post_repr(out_label, syn_desc, out_desc)
+ if not post.has_bef_update(_post_repr):
+ syn_cls = syn_desc()
+ out_cls = out_desc()
+
+ # synapse and output initialization
+ post.add_inp_fun(proj_name, out_cls, label=out_label)
+ post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls))
+ syn = post.get_bef_update(_post_repr).syn
+ out = post.get_bef_update(_post_repr).out
+ return syn, out
+
+
+class _AlignPost(DynamicalSystem):
+ def __init__(self,
+ syn: Callable,
+ out: JointType[DynamicalSystem, BindCondData]):
+ super().__init__()
+ self.syn = syn
+ self.out = out
+
+ def update(self, *args, **kwargs):
+ self.out.bind_cond(self.syn(*args, **kwargs))
+
+ def reset_state(self, *args, **kwargs):
+ pass
+
+
+class HalfProjAlignPostMg(Projection):
+ r"""Defining the half-part of a synaptic projection with the align-post reduction and automatic synapse merging.
+
+ The ``half-part`` means that the model only needs to provide half of the information needed for a projection,
+ namely ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function requires
+ the spiking input to be provided manually.
+
+ The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
+
+ The ``merging`` means that the same delay model is shared by all synapses, and synapse models with the same
+ parameters (such as time constants) will also share the same synaptic variables.
+
+ All align-post projection models prefer the event-driven computation mode, which means that the
+ ``comm`` model should be an event-driven model.
+
+ **Code Examples**
+
+ To define an E/I balanced network model.
+
+ .. code-block:: python
+
+ import brainpy as bp
+ import brainpy.math as bm
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
+ self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon.desc(size=4000, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.N)
+ self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon.desc(size=4000, tau=10.),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.N)
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(spk[:3200])
+ self.I(spk[3200:])
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+ Args:
+ comm: The synaptic communication.
+ syn: The synaptic dynamics.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ out_label: str. The prefix of the output function.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ comm: DynamicalSystem,
+ syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]],
+ out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
+ post: DynamicalSystem,
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]])
+ check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+
+ # synapse and output initialization
+ syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name)
+
+ # references
+ self.refs = dict(post=post) # invisible to ``self.nodes()``
+ self.refs['syn'] = syn
+ self.refs['out'] = out
+ self.refs['comm'] = comm # unify the access
+
+ def update(self, x):
+ current = self.comm(x)
+ self.refs['syn'].add_current(current) # synapse post current
+ return current
+
+ syn = property(lambda self: self.refs['syn'])
+ out = property(lambda self: self.refs['out'])
+ post = property(lambda self: self.refs['post'])
+
+
+class FullProjAlignPostMg(Projection):
+ """Full-chain synaptic projection with the align-post reduction and the automatic synapse merging.
+
+ The ``full-chain`` means that the model needs to provide all information needed for a projection,
+ including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``.
+
+ The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
+
+ The ``merging`` means that the same delay model is shared by all synapses, and synapse models with the same
+ parameters (such as time constants) will also share the same synaptic variables.
+
+ All align-post projection models prefer the event-driven computation mode, which means that the
+ ``comm`` model should be an event-driven model.
+
+ Moreover, it's worth noting that ``FullProjAlignPostMg`` has a different updating order from all align-pre
+ projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``,
+ while the updating order of align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``.
+
+ **Code Examples**
+
+ To define an E/I balanced network model.
+
+ .. code-block:: python
+
+ import brainpy as bp
+ import brainpy.math as bm
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.FullProjAlignPostMg(pre=self.E,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPostMg(pre=self.E,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon.desc(size=ni, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPostMg(pre=self.I,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon.desc(size=ne, tau=10.),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPostMg(pre=self.I,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ delay: The synaptic delay.
+ comm: The synaptic communication.
+ syn: The synaptic dynamics.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
+ delay: Union[None, int, float],
+ comm: DynamicalSystem,
+ syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]],
+ out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
+ post: DynamicalSystem,
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]])
+ check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+
+ # delay initialization
+ delay_cls = register_delay_by_return(pre)
+ delay_cls.register_entry(self.name, delay)
+
+ # synapse and output initialization
+ syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name)
+
+ # references
+ self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
+ self.refs['syn'] = syn # invisible to ``self.node()``
+ self.refs['out'] = out # invisible to ``self.node()``
+ # unify the access
+ self.refs['comm'] = comm
+ self.refs['delay'] = pre.get_aft_update(delay_identifier)
+
+ def update(self):
+ x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
+ current = self.comm(x)
+ self.refs['syn'].add_current(current) # synapse post current
+ return current
+
+ syn = property(lambda self: self.refs['syn'])
+ out = property(lambda self: self.refs['out'])
+ delay = property(lambda self: self.refs['delay'])
+ pre = property(lambda self: self.refs['pre'])
+ post = property(lambda self: self.refs['post'])
+
+
+class HalfProjAlignPost(Projection):
+ """Defining the half-part of a synaptic projection with the align-post reduction.
+
+ The ``half-part`` means that the model only needs to provide half of the information needed for a projection,
+ namely ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function requires
+ the spiking input to be provided manually.
+
+ The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
+
+ All align-post projection models prefer the event-driven computation mode, which means that the
+ ``comm`` model should be an event-driven model.
+
+ To simulate an E/I balanced network:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
+ self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon(size=4000, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.N)
+ self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon(size=4000, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.N)
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(spk[:3200])
+ self.I(spk[3200:])
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ comm: The synaptic communication.
+ syn: The synaptic dynamics.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ comm: DynamicalSystem,
+ syn: JointType[DynamicalSystem, AlignPost],
+ out: JointType[DynamicalSystem, BindCondData],
+ post: DynamicalSystem,
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(syn, JointType[DynamicalSystem, AlignPost])
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+ self.syn = syn
+ self.out = out
+
+ # synapse and output initialization
+ post.add_inp_fun(self.name, out, label=out_label)
+
+ # reference
+ self.refs = dict()
+ # invisible to ``self.nodes()``
+ self.refs['post'] = post
+ self.refs['syn'] = syn
+ self.refs['out'] = out
+ # unify the access
+ self.refs['comm'] = comm
+
+ def update(self, x):
+ current = self.comm(x)
+ g = self.syn(current)  # reuse the communication result computed above
+ self.refs['out'].bind_cond(g) # synapse post current
+ return current
+
+ post = property(lambda self: self.refs['post'])
+
+
+class FullProjAlignPost(Projection):
+ """Full-chain synaptic projection with the align-post reduction.
+
+ The ``full-chain`` means that the model needs to provide all information needed for a projection,
+ including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``.
+
+ The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
+
+ All align-post projection models prefer the event-driven computation mode, which means that the
+ ``comm`` model should be an event-driven model.
+
+ Moreover, it's worth noting that ``FullProjAlignPost`` has a different updating order from all align-pre
+ projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``,
+ while the updating order of align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``.
+
+ To define and simulate an E/I balanced network model:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.FullProjAlignPost(pre=self.E,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon(size=ne, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPost(pre=self.E,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon(size=ni, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPost(pre=self.I,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon(size=ne, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPost(pre=self.I,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon(size=ni, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ delay: The synaptic delay.
+ comm: The synaptic communication.
+ syn: The synaptic dynamics.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
+ delay: Union[None, int, float],
+ comm: DynamicalSystem,
+ syn: JointType[DynamicalSystem, AlignPost],
+ out: JointType[DynamicalSystem, BindCondData],
+ post: DynamicalSystem,
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(syn, JointType[DynamicalSystem, AlignPost])
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+ self.syn = syn
+
+ # delay initialization
+ delay_cls = register_delay_by_return(pre)
+ delay_cls.register_entry(self.name, delay)
+
+ # synapse and output initialization
+ post.add_inp_fun(self.name, out, label=out_label)
+
+ # references
+ self.refs = dict()
+ # invisible to ``self.nodes()``
+ self.refs['pre'] = pre
+ self.refs['post'] = post
+ self.refs['out'] = out
+ # unify the access
+ self.refs['delay'] = delay_cls
+ self.refs['comm'] = comm
+ self.refs['syn'] = syn
+
+ def update(self):
+ x = self.refs['delay'].at(self.name)
+ g = self.syn(self.comm(x))
+ self.refs['out'].bind_cond(g) # synapse post current
+ return g
+
+ delay = property(lambda self: self.refs['delay'])
+ pre = property(lambda self: self.refs['pre'])
+ post = property(lambda self: self.refs['post'])
+ out = property(lambda self: self.refs['out'])
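The ``merging`` behaviour described in the docstrings above comes down to ``get_post_repr``: projections whose ``(out_label, syn, out)`` descriptors render to the same identifier string reuse a single synapse/output pair registered on the post group. A plain-Python sketch of that caching rule (illustrative only, not the BrainPy internals):

.. code-block:: python

    _shared = {}   # stands in for the post group's "before updates" registry

    def get_or_create(out_label, syn_id, out_id, factory):
        key = f'{out_label} // {syn_id} // {out_id}'   # same layout as get_post_repr
        if key not in _shared:                          # first projection with this signature
            _shared[key] = factory()                    # builds the shared instance
        return _shared[key]                             # later projections reuse it

    a = get_or_create('exc', 'Expon(tau=5.)', 'COBA(E=0.)', dict)
    b = get_or_create('exc', 'Expon(tau=5.)', 'COBA(E=0.)', dict)
    c = get_or_create('inh', 'Expon(tau=10.)', 'COBA(E=-80.)', dict)
    print(a is b, a is c)   # True False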
diff --git a/brainpy/_src/dyn/projections/align_pre.py b/brainpy/_src/dyn/projections/align_pre.py
new file mode 100644
index 000000000..6e5cd223a
--- /dev/null
+++ b/brainpy/_src/dyn/projections/align_pre.py
@@ -0,0 +1,606 @@
+from typing import Optional, Union
+
+from brainpy import math as bm, check
+from brainpy._src.delay import (Delay, DelayAccess, init_delay_by_return, register_delay_by_return)
+from brainpy._src.dynsys import DynamicalSystem, Projection
+from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData)
+from .utils import _get_return
+
+__all__ = [
+ 'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg',
+ 'FullProjAlignPreSD', 'FullProjAlignPreDS',
+]
+
+
+def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None):
+ _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}'
+ if not delay_cls.has_bef_update(_syn_id):
+ # delay
+ delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name)
+ # synapse
+ syn_cls = syn_desc()
+ # add to "after_updates"
+ delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls))
+ syn = delay_cls.get_bef_update(_syn_id).syn
+ return syn
+
+
+class _AlignPreMg(DynamicalSystem):
+ def __init__(self, access, syn):
+ super().__init__()
+ self.access = access
+ self.syn = syn
+
+ def update(self, *args, **kwargs):
+ return self.syn(self.access())
+
+ def reset_state(self, *args, **kwargs):
+ pass
+
+
+def align_pre1_add_bef_update(syn_desc, pre):
+ _syn_id = f'{syn_desc.identifier} // Delay'
+ if not pre.has_aft_update(_syn_id):
+ # "syn_cls" needs an instance of "ProjAutoDelay"
+ syn_cls: SupportAutoDelay = syn_desc()
+ delay_cls = init_delay_by_return(syn_cls.return_info())
+ # add to "after_updates"
+ pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls))
+ delay_cls: Delay = pre.get_aft_update(_syn_id).delay
+ syn = pre.get_aft_update(_syn_id).syn
+ return delay_cls, syn
+
+
+class _AlignPre(DynamicalSystem):
+ def __init__(self, syn, delay=None):
+ super().__init__()
+ self.syn = syn
+ self.delay = delay
+
+ def update(self, x):
+ if self.delay is None:
+ return x >> self.syn
+ else:
+ return x >> self.syn >> self.delay
+
+ def reset_state(self, *args, **kwargs):
+ pass
+
+
+class FullProjAlignPreSDMg(Projection):
+ """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging.
+
+ The ``full-chain`` means that the model needs to provide all information needed for a projection,
+ including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.
+
+ The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
+
+ The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
+ synapse states to the delay model, and finally computes the synaptic current.
+
+ The ``merging`` means that the same delay model is shared by all synapses, and synapse models with the same
+ parameters (such as time constants) will also share the same synaptic variables.
+
+ Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg`` facilitates event-driven computation.
+ This is because ``comm`` operates on the synapse state, which is a floating-point value rather
+ than a spike. To use event-driven computation, please use align-post projections.
+
+ To simulate an E/I balanced network model:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ syn: The synaptic dynamics.
+ delay: The synaptic delay.
+ comm: The synaptic communication.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: DynamicalSystem,
+ syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]],
+ delay: Union[None, int, float],
+ comm: DynamicalSystem,
+ out: JointType[DynamicalSystem, BindCondData],
+ post: DynamicalSystem,
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, DynamicalSystem)
+ check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]])
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+
+ # synapse and delay initialization
+ delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre)
+ delay_cls.register_entry(self.name, delay)
+
+ # output initialization
+ post.add_inp_fun(self.name, out, label=out_label)
+
+ # references
+ self.refs = dict()
+ # invisible to ``self.nodes()``
+ self.refs['pre'] = pre
+ self.refs['post'] = post
+ self.refs['out'] = out
+ self.refs['delay'] = delay_cls
+ self.refs['syn'] = syn_cls
+ # unify the access
+ self.refs['comm'] = comm
+
+ def update(self, x=None):
+ if x is None:
+ x = self.refs['delay'].at(self.name)
+ current = self.comm(x)
+ self.refs['out'].bind_cond(current)
+ return current
+
+ pre = property(lambda self: self.refs['pre'])
+ post = property(lambda self: self.refs['post'])
+ syn = property(lambda self: self.refs['syn'])
+ delay = property(lambda self: self.refs['delay'])
+ out = property(lambda self: self.refs['out'])
+
+
+class FullProjAlignPreDSMg(Projection):
+ """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging.
+
+ The ``full-chain`` means that the model needs to provide all information needed for a projection,
+ including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
+ Note that, compared to ``FullProjAlignPreSDMg``, the positions of ``delay`` and ``syn`` are exchanged.
+
+ The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
+
+ The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the
+ spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current.
+
+ The ``merging`` means that the same delay model is shared by all synapses, and synapse models with the same
+ parameters (such as time constants) will also share the same synaptic variables.
+
+ Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates event-driven computation.
+ This is because ``comm`` operates on the synapse state, which is a floating-point value rather
+ than a spike. To use event-driven computation, please use align-post projections.
+
+
+ To simulate an E/I balanced network model:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E,
+ delay=0.1,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E,
+ delay=0.1,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I,
+ delay=0.1,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I,
+ delay=0.1,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ delay: The synaptic delay.
+ syn: The synaptic dynamics.
+ comm: The synaptic communication.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
+ delay: Union[None, int, float],
+ syn: ParamDescriber[DynamicalSystem],
+ comm: DynamicalSystem,
+ out: JointType[DynamicalSystem, BindCondData],
+ post: DynamicalSystem,
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
+ check.is_instance(syn, ParamDescriber[DynamicalSystem])
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+
+ # delay initialization
+ delay_cls = register_delay_by_return(pre)
+
+ # synapse initialization
+ syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name)
+
+ # output initialization
+ post.add_inp_fun(self.name, out, label=out_label)
+
+ # references
+ self.refs = dict()
+ # invisible to `self.nodes()`
+ self.refs['pre'] = pre
+ self.refs['post'] = post
+ self.refs['syn'] = syn_cls
+ self.refs['out'] = out
+ # unify the access
+ self.refs['comm'] = comm
+
+ def update(self):
+ x = _get_return(self.refs['syn'].return_info())
+ current = self.comm(x)
+ self.refs['out'].bind_cond(current)
+ return current
+
+ pre = property(lambda self: self.refs['pre'])
+ post = property(lambda self: self.refs['post'])
+ syn = property(lambda self: self.refs['syn'])
+ out = property(lambda self: self.refs['out'])
+
+
+class FullProjAlignPreSD(Projection):
+ """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating.
+
+ The ``full-chain`` means that the model needs to provide all information needed for a projection,
+ including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.
+
+ The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
+
+ The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
+ synapse states to the delay model, and finally computes the synaptic current.
+
+ Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS`` facilitates event-driven computation.
+ This is because ``comm`` operates on the synapse state, which is a floating-point value rather
+ than a spike. To use event-driven computation, please use align-post projections.
+
+
+ To simulate an E/I balanced network model:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.FullProjAlignPreSD(pre=self.E,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPreSD(pre=self.E,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPreSD(pre=self.I,
+ syn=bp.dyn.Expon(size=ni, tau=10.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPreSD(pre=self.I,
+ syn=bp.dyn.Expon(size=ni, tau=10.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ syn: The synaptic dynamics.
+ delay: The synaptic delay.
+ comm: The synaptic communication.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: DynamicalSystem,
+ syn: JointType[DynamicalSystem, SupportAutoDelay],
+ delay: Union[None, int, float],
+ comm: DynamicalSystem,
+ out: JointType[DynamicalSystem, BindCondData],
+ post: DynamicalSystem,
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, DynamicalSystem)
+ check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay])
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+
+ # synapse and delay initialization
+ delay_cls = init_delay_by_return(syn.return_info())
+ delay_cls.register_entry(self.name, delay)
+ pre.add_aft_update(self.name, _AlignPre(syn, delay_cls))
+
+ # output initialization
+ post.add_inp_fun(self.name, out, label=out_label)
+
+ # references
+ self.refs = dict()
+ # invisible to ``self.nodes()``
+ self.refs['pre'] = pre
+ self.refs['post'] = post
+ self.refs['out'] = out
+ self.refs['delay'] = delay_cls
+ self.refs['syn'] = syn
+ # unify the access
+ self.refs['comm'] = comm
+
+ def update(self, x=None):
+ if x is None:
+ x = self.refs['delay'].at(self.name)
+ current = self.comm(x)
+ self.refs['out'].bind_cond(current)
+ return current
+
+ pre = property(lambda self: self.refs['pre'])
+ post = property(lambda self: self.refs['post'])
+ syn = property(lambda self: self.refs['syn'])
+ delay = property(lambda self: self.refs['delay'])
+ out = property(lambda self: self.refs['out'])
+
+
+class FullProjAlignPreDS(Projection):
+ """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating.
+
+ The ``full-chain`` means that the model must provide all the information needed for a projection,
+ including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.
+ Note that, compared to ``FullProjAlignPreSD``, the order of ``delay`` and ``syn`` is exchanged.
+
+ The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
+
+ The ``delay+synapse updating`` means that the projection first delivers the pre-synaptic output (usually the
+ spikes) to the delay model, then computes the synapse states, and finally computes the synaptic current.
+
+ Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` supports event-driven computation, because ``comm``
+ operates on the synapse state, which is a floating-point quantity rather than a spike train. For event-driven
+ computation, please use the align-post projections.
+
+
+ To simulate an E/I balanced network model:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.FullProjAlignPreDS(pre=self.E,
+ delay=0.1,
+ syn=bp.dyn.Expon(size=ne, tau=5.),
+ comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPreDS(pre=self.E,
+ delay=0.1,
+ syn=bp.dyn.Expon(size=ne, tau=5.),
+ comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPreDS(pre=self.I,
+ delay=0.1,
+ syn=bp.dyn.Expon(size=ni, tau=10.),
+ comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPreDS(pre=self.I,
+ delay=0.1,
+ syn=bp.dyn.Expon(size=ni, tau=10.),
+ comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ delay: The synaptic delay.
+ syn: The synaptic dynamics.
+ comm: The synaptic communication.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
+ delay: Union[None, int, float],
+ syn: DynamicalSystem,
+ comm: DynamicalSystem,
+ out: JointType[DynamicalSystem, BindCondData],
+ post: DynamicalSystem,
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
+ check.is_instance(syn, DynamicalSystem)
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+ self.syn = syn
+
+ # delay initialization
+ delay_cls = register_delay_by_return(pre)
+ delay_cls.register_entry(self.name, delay)
+
+ # output initialization
+ post.add_inp_fun(self.name, out, label=out_label)
+
+ # references
+ self.refs = dict()
+ # invisible to ``self.nodes()``
+ self.refs['pre'] = pre
+ self.refs['post'] = post
+ self.refs['out'] = out
+ self.refs['delay'] = delay_cls
+ # unify the access
+ self.refs['syn'] = syn
+ self.refs['comm'] = comm
+
+ def update(self):
+ spk = self.refs['delay'].at(self.name)
+ g = self.comm(self.syn(spk))
+ self.refs['out'].bind_cond(g)
+ return g
+
+ pre = property(lambda self: self.refs['pre'])
+ post = property(lambda self: self.refs['post'])
+ delay = property(lambda self: self.refs['delay'])
+ out = property(lambda self: self.refs['out'])
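
The practical difference between the two classes above is only where the delay sits relative to the synapse: ``FullProjAlignPreSD`` delays the synapse state, while ``FullProjAlignPreDS`` delays the pre-synaptic spikes and then filters the delayed spikes. Below is a minimal sketch of one projection of each kind between the same pair of populations; it reuses the ``bp.dyn``/``bp.dnn`` APIs shown in the docstring examples, and the sizes, probabilities, and weights are illustrative rather than taken from this patch.

.. code-block:: python

   import brainpy as bp

   pre = bp.dyn.LifRef(100)
   post = bp.dyn.LifRef(10)

   # synapse + delay updating: the exponential synapse state is computed at the
   # pre-synaptic dimension, and that state is what gets delayed
   sd = bp.dyn.FullProjAlignPreSD(pre=pre,
                                  syn=bp.dyn.Expon(size=100, tau=5.),
                                  delay=2.,
                                  comm=bp.dnn.JitFPHomoLinear(100, 10, prob=0.1, weight=0.5),
                                  out=bp.dyn.COBA(E=0.),
                                  post=post)

   # delay + synapse updating: the pre-synaptic spikes are delayed first, and the
   # synapse then filters the delayed spikes
   ds = bp.dyn.FullProjAlignPreDS(pre=pre,
                                  delay=2.,
                                  syn=bp.dyn.Expon(size=100, tau=5.),
                                  comm=bp.dnn.JitFPHomoLinear(100, 10, prob=0.1, weight=0.5),
                                  out=bp.dyn.COBA(E=0.),
                                  post=post)

In both cases ``syn`` lives at the pre-synaptic dimension (size 100 here), which is what the align-pre reduction refers to.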
+
diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py
deleted file mode 100644
index 2616e928b..000000000
--- a/brainpy/_src/dyn/projections/aligns.py
+++ /dev/null
@@ -1,1053 +0,0 @@
-from typing import Optional, Callable, Union
-
-from brainpy import math as bm, check
-from brainpy._src.delay import (Delay, DelayAccess, delay_identifier,
- init_delay_by_return, register_delay_by_return)
-from brainpy._src.dynsys import DynamicalSystem, Projection
-from brainpy._src.mixin import (JointType, ParamDescriber, ReturnInfo,
- SupportAutoDelay, BindCondData, AlignPost)
-
-__all__ = [
- 'VanillaProj',
- 'ProjAlignPostMg1', 'ProjAlignPostMg2',
- 'ProjAlignPost1', 'ProjAlignPost2',
- 'ProjAlignPreMg1', 'ProjAlignPreMg2',
- 'ProjAlignPre1', 'ProjAlignPre2',
-]
-
-
-def get_post_repr(out_label, syn, out):
- return f'{out_label} // {syn.identifier} // {out.identifier}'
-
-
-def add_inp_fun(out_label, proj_name, out, post):
- # synapse and output initialization
- if out_label is None:
- out_name = proj_name
- else:
- out_name = f'{out_label} // {proj_name}'
- post.add_inp_fun(out_name, out)
-
-
-def align_post_add_bef_update(out_label, syn_desc, out_desc, post, proj_name):
- # synapse and output initialization
- _post_repr = get_post_repr(out_label, syn_desc, out_desc)
- if not post.has_bef_update(_post_repr):
- syn_cls = syn_desc()
- out_cls = out_desc()
-
- # synapse and output initialization
- if out_label is None:
- out_name = proj_name
- else:
- out_name = f'{out_label} // {proj_name}'
- post.add_inp_fun(out_name, out_cls)
- post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls))
- syn = post.get_bef_update(_post_repr).syn
- out = post.get_bef_update(_post_repr).out
- return syn, out
-
-
-def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None):
- _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}'
- if not delay_cls.has_bef_update(_syn_id):
- # delay
- delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name)
- # synapse
- syn_cls = syn_desc()
- # add to "after_updates"
- delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls))
- syn = delay_cls.get_bef_update(_syn_id).syn
- return syn
-
-
-def align_pre1_add_bef_update(syn_desc, pre):
- _syn_id = f'{syn_desc.identifier} // Delay'
- if not pre.has_aft_update(_syn_id):
- # "syn_cls" needs an instance of "ProjAutoDelay"
- syn_cls: SupportAutoDelay = syn_desc()
- delay_cls = init_delay_by_return(syn_cls.return_info())
- # add to "after_updates"
- pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls))
- delay_cls: Delay = pre.get_aft_update(_syn_id).delay
- syn = pre.get_aft_update(_syn_id).syn
- return delay_cls, syn
-
-
-class _AlignPre(DynamicalSystem):
- def __init__(self, syn, delay=None):
- super().__init__()
- self.syn = syn
- self.delay = delay
-
- def update(self, x):
- if self.delay is None:
- return x >> self.syn
- else:
- return x >> self.syn >> self.delay
-
- def reset_state(self, *args, **kwargs):
- pass
-
-
-class _AlignPost(DynamicalSystem):
- def __init__(self,
- syn: Callable,
- out: JointType[DynamicalSystem, BindCondData]):
- super().__init__()
- self.syn = syn
- self.out = out
-
- def update(self, *args, **kwargs):
- self.out.bind_cond(self.syn(*args, **kwargs))
-
- def reset_state(self, *args, **kwargs):
- pass
-
-
-class _AlignPreMg(DynamicalSystem):
- def __init__(self, access, syn):
- super().__init__()
- self.access = access
- self.syn = syn
-
- def update(self, *args, **kwargs):
- return self.syn(self.access())
-
- def reset_state(self, *args, **kwargs):
- pass
-
-
-def _get_return(return_info):
- if isinstance(return_info, bm.Variable):
- return return_info.value
- elif isinstance(return_info, ReturnInfo):
- return return_info.get_data()
- else:
- raise NotImplementedError
-
-
-class VanillaProj(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group.
-
- **Code Examples**
-
- To simulate an E/I balanced network model:
-
- .. code-block::
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
- self.syn1 = bp.dyn.Expon(size=3200, tau=5.)
- self.syn2 = bp.dyn.Expon(size=800, tau=10.)
- self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.N)
- self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.N)
-
- def update(self, input):
- spk = self.delay.at('I')
- self.E(self.syn1(spk[:3200]))
- self.I(self.syn2(spk[3200:]))
- self.delay(self.N(input))
- return self.N.spike.value
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
-
- Args:
- comm: The synaptic communication.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- comm: DynamicalSystem,
- out: JointType[DynamicalSystem, BindCondData],
- post: DynamicalSystem,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(out, JointType[DynamicalSystem, BindCondData])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
-
- # output initialization
- post.add_inp_fun(self.name, out)
-
- # references
- self.refs = dict(post=post, out=out) # invisible to ``self.nodes()``
- self.refs['comm'] = comm # unify the access
-
- def update(self, x):
- current = self.comm(x)
- self.refs['out'].bind_cond(current)
- return current
-
-
-class ProjAlignPostMg1(Projection):
- r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
-
- **Code Examples**
-
- To define an E/I balanced network model.
-
- .. code-block:: python
-
- import brainpy as bp
- import brainpy.math as bm
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
- self.E = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
- syn=bp.dyn.Expon.desc(size=4000, tau=5.),
- out=bp.dyn.COBA.desc(E=0.),
- post=self.N)
- self.I = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
- syn=bp.dyn.Expon.desc(size=4000, tau=10.),
- out=bp.dyn.COBA.desc(E=-80.),
- post=self.N)
-
- def update(self, input):
- spk = self.delay.at('I')
- self.E(spk[:3200])
- self.I(spk[3200:])
- self.delay(self.N(input))
- return self.N.spike.value
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
- Args:
- comm: The synaptic communication.
- syn: The synaptic dynamics.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- out_label: str. The prefix of the output function.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- comm: DynamicalSystem,
- syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]],
- out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
- post: DynamicalSystem,
- out_label: Optional[str] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]])
- check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
-
- # synapse and output initialization
- syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name)
-
- # references
- self.refs = dict(post=post) # invisible to ``self.nodes()``
- self.refs['syn'] = syn
- self.refs['out'] = out
- self.refs['comm'] = comm # unify the access
-
- def update(self, x):
- current = self.comm(x)
- self.refs['syn'].add_current(current) # synapse post current
- return current
-
-
-class ProjAlignPostMg2(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
-
- **Code Examples**
-
- To define an E/I balanced network model.
-
- .. code-block:: python
-
- import brainpy as bp
- import brainpy.math as bm
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- ne, ni = 3200, 800
- self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPostMg2(pre=self.E,
- delay=0.1,
- comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- out=bp.dyn.COBA.desc(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPostMg2(pre=self.E,
- delay=0.1,
- comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
- syn=bp.dyn.Expon.desc(size=ni, tau=5.),
- out=bp.dyn.COBA.desc(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPostMg2(pre=self.I,
- delay=0.1,
- comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
- syn=bp.dyn.Expon.desc(size=ne, tau=10.),
- out=bp.dyn.COBA.desc(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPostMg2(pre=self.I,
- delay=0.1,
- comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- out=bp.dyn.COBA.desc(E=-80.),
- post=self.I)
-
- def update(self, inp):
- self.E2E()
- self.E2I()
- self.I2E()
- self.I2I()
- self.E(inp)
- self.I(inp)
- return self.E.spike
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
- Args:
- pre: The pre-synaptic neuron group.
- delay: The synaptic delay.
- comm: The synaptic communication.
- syn: The synaptic dynamics.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- pre: JointType[DynamicalSystem, SupportAutoDelay],
- delay: Union[None, int, float],
- comm: DynamicalSystem,
- syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]],
- out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
- post: DynamicalSystem,
- out_label: Optional[str] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]])
- check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
-
- # delay initialization
- delay_cls = register_delay_by_return(pre)
- delay_cls.register_entry(self.name, delay)
-
- # synapse and output initialization
- syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name)
-
- # references
- self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
- self.refs['syn'] = syn # invisible to ``self.node()``
- self.refs['out'] = out # invisible to ``self.node()``
- # unify the access
- self.refs['comm'] = comm
- self.refs['delay'] = pre.get_aft_update(delay_identifier)
-
- def update(self):
- x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
- current = self.comm(x)
- self.refs['syn'].add_current(current) # synapse post current
- return current
-
-
-class ProjAlignPost1(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
-
- To simulate an E/I balanced network:
-
- .. code-block::
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
- self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
- syn=bp.dyn.Expon(size=4000, tau=5.),
- out=bp.dyn.COBA(E=0.),
- post=self.N)
- self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
- syn=bp.dyn.Expon(size=4000, tau=10.),
- out=bp.dyn.COBA(E=-80.),
- post=self.N)
-
- def update(self, input):
- spk = self.delay.at('I')
- self.E(spk[:3200])
- self.I(spk[3200:])
- self.delay(self.N(input))
- return self.N.spike.value
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
-
- Args:
- comm: The synaptic communication.
- syn: The synaptic dynamics.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- comm: DynamicalSystem,
- syn: JointType[DynamicalSystem, AlignPost],
- out: JointType[DynamicalSystem, BindCondData],
- post: DynamicalSystem,
- out_label: Optional[str] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(syn, JointType[DynamicalSystem, AlignPost])
- check.is_instance(out, JointType[DynamicalSystem, BindCondData])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
- self.syn = syn
- self.out = out
-
- # synapse and output initialization
- add_inp_fun(out_label, self.name, out, post)
-
- # reference
- self.refs = dict()
- # invisible to ``self.nodes()``
- self.refs['post'] = post
- self.refs['syn'] = syn
- self.refs['out'] = out
- # unify the access
- self.refs['comm'] = comm
-
- def update(self, x):
- current = self.comm(x)
- g = self.syn(self.comm(x))
- self.refs['out'].bind_cond(g) # synapse post current
- return current
-
-
-class ProjAlignPost2(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
-
- To simulate and define an E/I balanced network model:
-
- .. code-block:: python
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- ne, ni = 3200, 800
- self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPost2(pre=self.E,
- delay=0.1,
- comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
- syn=bp.dyn.Expon(size=ne, tau=5.),
- out=bp.dyn.COBA(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPost2(pre=self.E,
- delay=0.1,
- comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
- syn=bp.dyn.Expon(size=ni, tau=5.),
- out=bp.dyn.COBA(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPost2(pre=self.I,
- delay=0.1,
- comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
- syn=bp.dyn.Expon(size=ne, tau=10.),
- out=bp.dyn.COBA(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPost2(pre=self.I,
- delay=0.1,
- comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
- syn=bp.dyn.Expon(size=ni, tau=10.),
- out=bp.dyn.COBA(E=-80.),
- post=self.I)
-
- def update(self, inp):
- self.E2E()
- self.E2I()
- self.I2E()
- self.I2I()
- self.E(inp)
- self.I(inp)
- return self.E.spike
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
-
- Args:
- pre: The pre-synaptic neuron group.
- delay: The synaptic delay.
- comm: The synaptic communication.
- syn: The synaptic dynamics.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- pre: JointType[DynamicalSystem, SupportAutoDelay],
- delay: Union[None, int, float],
- comm: DynamicalSystem,
- syn: JointType[DynamicalSystem, AlignPost],
- out: JointType[DynamicalSystem, BindCondData],
- post: DynamicalSystem,
- out_label: Optional[str] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(syn, JointType[DynamicalSystem, AlignPost])
- check.is_instance(out, JointType[DynamicalSystem, BindCondData])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
- self.syn = syn
-
- # delay initialization
- delay_cls = register_delay_by_return(pre)
- delay_cls.register_entry(self.name, delay)
-
- # synapse and output initialization
- add_inp_fun(out_label, self.name, out, post)
-
- # references
- self.refs = dict()
- # invisible to ``self.nodes()``
- self.refs['pre'] = pre
- self.refs['post'] = post
- self.refs['out'] = out
- # unify the access
- self.refs['delay'] = delay_cls
- self.refs['comm'] = comm
- self.refs['syn'] = syn
-
- def update(self):
- x = self.refs['delay'].at(self.name)
- g = self.syn(self.comm(x))
- self.refs['out'].bind_cond(g) # synapse post current
- return g
-
-
-class ProjAlignPreMg1(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
-
- To simulate an E/I balanced network model:
-
- .. code-block:: python
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- ne, ni = 3200, 800
- self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- delay=0.1,
- comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- delay=0.1,
- comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- delay=0.1,
- comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- delay=0.1,
- comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.I)
-
- def update(self, inp):
- self.E2E()
- self.E2I()
- self.I2E()
- self.I2I()
- self.E(inp)
- self.I(inp)
- return self.E.spike
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
-
- Args:
- pre: The pre-synaptic neuron group.
- syn: The synaptic dynamics.
- delay: The synaptic delay.
- comm: The synaptic communication.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- pre: DynamicalSystem,
- syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]],
- delay: Union[None, int, float],
- comm: DynamicalSystem,
- out: JointType[DynamicalSystem, BindCondData],
- post: DynamicalSystem,
- out_label: Optional[str] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(pre, DynamicalSystem)
- check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]])
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(out, JointType[DynamicalSystem, BindCondData])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
-
- # synapse and delay initialization
- delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre)
- delay_cls.register_entry(self.name, delay)
-
- # output initialization
- add_inp_fun(out_label, self.name, out, post)
-
- # references
- self.refs = dict()
- # invisible to ``self.nodes()``
- self.refs['pre'] = pre
- self.refs['post'] = post
- self.refs['out'] = out
- self.refs['delay'] = delay_cls
- self.refs['syn'] = syn_cls
- # unify the access
- self.refs['comm'] = comm
-
- def update(self, x=None):
- if x is None:
- x = self.refs['delay'].at(self.name)
- current = self.comm(x)
- self.refs['out'].bind_cond(current)
- return current
-
-
-class ProjAlignPreMg2(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
-
- To simulate an E/I balanced network model:
-
- .. code-block:: python
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- ne, ni = 3200, 800
- self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E,
- delay=0.1,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E,
- delay=0.1,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I,
- delay=0.1,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I,
- delay=0.1,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.I)
-
- def update(self, inp):
- self.E2E()
- self.E2I()
- self.I2E()
- self.I2I()
- self.E(inp)
- self.I(inp)
- return self.E.spike
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
-
- Args:
- pre: The pre-synaptic neuron group.
- delay: The synaptic delay.
- syn: The synaptic dynamics.
- comm: The synaptic communication.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- pre: JointType[DynamicalSystem, SupportAutoDelay],
- delay: Union[None, int, float],
- syn: ParamDescriber[DynamicalSystem],
- comm: DynamicalSystem,
- out: JointType[DynamicalSystem, BindCondData],
- post: DynamicalSystem,
- out_label: Optional[str] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
- check.is_instance(syn, ParamDescriber[DynamicalSystem])
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(out, JointType[DynamicalSystem, BindCondData])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
-
- # delay initialization
- delay_cls = register_delay_by_return(pre)
-
- # synapse initialization
- syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name)
-
- # output initialization
- add_inp_fun(out_label, self.name, out, post)
-
- # references
- self.refs = dict()
- # invisible to `self.nodes()`
- self.refs['pre'] = pre
- self.refs['post'] = post
- self.refs['syn'] = syn_cls
- self.refs['out'] = out
- # unify the access
- self.refs['comm'] = comm
-
- def update(self):
- x = _get_return(self.refs['syn'].return_info())
- current = self.comm(x)
- self.refs['out'].bind_cond(current)
- return current
-
-
-class ProjAlignPre1(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
-
- To simulate an E/I balanced network model:
-
- .. code-block:: python
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- ne, ni = 3200, 800
- self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- delay=0.1,
- comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- delay=0.1,
- comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- delay=0.1,
- comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- delay=0.1,
- comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.I)
-
- def update(self, inp):
- self.E2E()
- self.E2I()
- self.I2E()
- self.I2I()
- self.E(inp)
- self.I(inp)
- return self.E.spike
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
-
- Args:
- pre: The pre-synaptic neuron group.
- syn: The synaptic dynamics.
- delay: The synaptic delay.
- comm: The synaptic communication.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- pre: DynamicalSystem,
- syn: JointType[DynamicalSystem, SupportAutoDelay],
- delay: Union[None, int, float],
- comm: DynamicalSystem,
- out: JointType[DynamicalSystem, BindCondData],
- post: DynamicalSystem,
- out_label: Optional[str] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(pre, DynamicalSystem)
- check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay])
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(out, JointType[DynamicalSystem, BindCondData])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
-
- # synapse and delay initialization
- delay_cls = init_delay_by_return(syn.return_info())
- delay_cls.register_entry(self.name, delay)
- pre.add_aft_update(self.name, _AlignPre(syn, delay_cls))
-
- # output initialization
- add_inp_fun(out_label, self.name, out, post)
-
- # references
- self.refs = dict()
- # invisible to ``self.nodes()``
- self.refs['pre'] = pre
- self.refs['post'] = post
- self.refs['out'] = out
- self.refs['delay'] = delay_cls
- self.refs['syn'] = syn
- # unify the access
- self.refs['comm'] = comm
-
- def update(self, x=None):
- if x is None:
- x = self.refs['delay'].at(self.name)
- current = self.comm(x)
- self.refs['out'].bind_cond(current)
- return current
-
-
-class ProjAlignPre2(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
-
- To simulate an E/I balanced network model:
-
- .. code-block:: python
-
- class EINet(bp.DynSysGroup):
- def __init__(self):
- super().__init__()
- ne, ni = 3200, 800
- self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E,
- delay=0.1,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E,
- delay=0.1,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I,
- delay=0.1,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I,
- delay=0.1,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.I)
-
- def update(self, inp):
- self.E2E()
- self.E2I()
- self.I2E()
- self.I2I()
- self.E(inp)
- self.I(inp)
- return self.E.spike
-
- model = EINet()
- indices = bm.arange(1000)
- spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
- bp.visualize.raster_plot(indices, spks, show=True)
-
-
- Args:
- pre: The pre-synaptic neuron group.
- delay: The synaptic delay.
- syn: The synaptic dynamics.
- comm: The synaptic communication.
- out: The synaptic output.
- post: The post-synaptic neuron group.
- name: str. The projection name.
- mode: Mode. The computing mode.
- """
-
- def __init__(
- self,
- pre: JointType[DynamicalSystem, SupportAutoDelay],
- delay: Union[None, int, float],
- syn: DynamicalSystem,
- comm: DynamicalSystem,
- out: JointType[DynamicalSystem, BindCondData],
- post: DynamicalSystem,
- out_label: Optional[str] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # synaptic models
- check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
- check.is_instance(syn, DynamicalSystem)
- check.is_instance(comm, DynamicalSystem)
- check.is_instance(out, JointType[DynamicalSystem, BindCondData])
- check.is_instance(post, DynamicalSystem)
- self.comm = comm
- self.syn = syn
-
- # delay initialization
- delay_cls = register_delay_by_return(pre)
- delay_cls.register_entry(self.name, delay)
-
- # output initialization
- add_inp_fun(out_label, self.name, out, post)
-
- # references
- self.refs = dict()
- # invisible to ``self.nodes()``
- self.refs['pre'] = pre
- self.refs['post'] = post
- self.refs['out'] = out
- self.refs['delay'] = delay_cls
- # unify the access
- self.refs['syn'] = syn
- self.refs['comm'] = comm
-
- def update(self):
- spk = self.refs['delay'].at(self.name)
- g = self.comm(self.syn(spk))
- self.refs['out'].bind_cond(g)
- return g
diff --git a/brainpy/_src/dyn/projections/base.py b/brainpy/_src/dyn/projections/base.py
new file mode 100644
index 000000000..44a2273a4
--- /dev/null
+++ b/brainpy/_src/dyn/projections/base.py
@@ -0,0 +1,12 @@
+from brainpy import math as bm
+from brainpy._src.mixin import ReturnInfo
+
+
+def _get_return(return_info):
+ if isinstance(return_info, bm.Variable):
+ return return_info.value
+ elif isinstance(return_info, ReturnInfo):
+ return return_info.get_data()
+ else:
+ raise NotImplementedError
+
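
The ``_get_return`` helper above is what the align-pre projections in this patch use to read the current synapse state inside ``update()`` (see ``_get_return(self.refs['syn'].return_info())`` earlier). A small, hypothetical usage sketch; the ``bp.dyn.Expon`` instance stands in for any model exposing ``return_info()``:

.. code-block:: python

   import brainpy as bp
   import brainpy.math as bm

   v = bm.Variable(bm.zeros(4))
   data = _get_return(v)                    # a Variable resolves to its current value

   syn = bp.dyn.Expon(10, tau=5.)           # any model exposing ``return_info()``
   state = _get_return(syn.return_info())   # handles both Variable and ReturnInfo results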
diff --git a/brainpy/_src/dyn/projections/delta.py b/brainpy/_src/dyn/projections/delta.py
new file mode 100644
index 000000000..19e4938cb
--- /dev/null
+++ b/brainpy/_src/dyn/projections/delta.py
@@ -0,0 +1,210 @@
+from typing import Optional, Union
+
+from brainpy import math as bm, check
+from brainpy._src.delay import (delay_identifier, register_delay_by_return)
+from brainpy._src.dynsys import DynamicalSystem, Projection
+from brainpy._src.mixin import (JointType, SupportAutoDelay)
+
+__all__ = [
+ 'HalfProjDelta', 'FullProjDelta',
+]
+
+
+class _Delta:
+ def __init__(self):
+ self._cond = None
+
+ def bind_cond(self, cond):
+ self._cond = cond
+
+ def __call__(self, *args, **kwargs):
+ r = self._cond
+ return r
+
+
+class HalfProjDelta(Projection):
+ """Defining the half-part of the synaptic projection for the Delta synapse model.
+
+ The synaptic projection requires the input is the spiking data, otherwise
+ the synapse is not the Delta synapse model.
+
+ The ``half-part`` means that the model only includes ``comm`` -> ``syn`` -> ``out`` -> ``post``.
+ Therefore, the model's ``update`` function needs the manual providing of the spiking input.
+
+ **Model Descriptions**
+
+ .. math::
+
+ I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D)
+
+ where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
+ :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
+ :math:`C` the set of neurons connected to the post-synaptic neuron,
+ and :math:`D` the transmission delay of chemical synapses.
+ For simplicity, the rise and decay phases of post-synaptic currents are
+ omitted in this model.
+
+
+ **Code Examples**
+
+ .. code-block::
+
+ import brainpy as bp
+ import brainpy.math as bm
+
+ class Net(bp.DynamicalSystem):
+ def __init__(self):
+ super().__init__()
+
+ self.pre = bp.dyn.PoissonGroup(10, 100.)
+ self.post = bp.dyn.LifRef(1)
+ self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post)
+
+ def update(self):
+ self.syn(self.pre())
+ self.post()
+ return self.post.V.value
+
+ net = Net()
+ indices = bm.arange(1000).to_numpy()
+ vs = bm.for_loop(net.step_run, indices, progress_bar=True)
+ bp.visualize.line_plot(indices, vs, show=True)
+
+ Args:
+ comm: DynamicalSystem. The synaptic communication.
+ post: DynamicalSystem. The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ comm: DynamicalSystem,
+ post: DynamicalSystem,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+
+ # output initialization
+ out = _Delta()
+ post.add_inp_fun(self.name, out, category='delta')
+
+ # references
+ self.refs = dict(post=post, out=out) # invisible to ``self.nodes()``
+ self.refs['comm'] = comm # unify the access
+
+ def update(self, x):
+ # call the communication
+ current = self.comm(x)
+ # bind the output
+ self.refs['out'].bind_cond(current)
+ # return the current, if needed
+ return current
+
+
+class FullProjDelta(Projection):
+ """Full-chain of the synaptic projection for the Delta synapse model.
+
+ The synaptic projection requires the input is the spiking data, otherwise
+ the synapse is not the Delta synapse model.
+
+ The ``full-chain`` means that the model needs to provide all information needed for a projection,
+ including ``pre`` -> ``delay`` -> ``comm`` -> ``post``.
+
+ **Model Descriptions**
+
+ .. math::
+
+ I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D)
+
+ where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
+ :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
+ :math:`C` the set of neurons connected to the post-synaptic neuron,
+ and :math:`D` the transmission delay of chemical synapses.
+ For simplicity, the rise and decay phases of post-synaptic currents are
+ omitted in this model.
+
+
+ **Code Examples**
+
+ .. code-block::
+
+ import brainpy as bp
+ import brainpy.math as bm
+
+
+ class Net(bp.DynamicalSystem):
+ def __init__(self):
+ super().__init__()
+
+ self.pre = bp.dyn.PoissonGroup(10, 100.)
+ self.post = bp.dyn.LifRef(1)
+ self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post)
+
+ def update(self):
+ self.syn()
+ self.pre()
+ self.post()
+ return self.post.V.value
+
+
+ net = Net()
+ indices = bm.arange(1000).to_numpy()
+ vs = bm.for_loop(net.step_run, indices, progress_bar=True)
+ bp.visualize.line_plot(indices, vs, show=True)
+
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ delay: The synaptic delay.
+ comm: DynamicalSystem. The synaptic communication.
+ post: DynamicalSystem. The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
+ delay: Union[None, int, float],
+ comm: DynamicalSystem,
+ post: DynamicalSystem,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+
+ # delay initialization
+ delay_cls = register_delay_by_return(pre)
+ delay_cls.register_entry(self.name, delay)
+
+ # output initialization
+ out = _Delta()
+ post.add_inp_fun(self.name, out, category='delta')
+
+ # references
+ self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()``
+ self.refs['comm'] = comm # unify the access
+ self.refs['delay'] = pre.get_aft_update(delay_identifier)
+
+ def update(self):
+ # get delay
+ x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
+ # call the communication
+ current = self.comm(x)
+ # bind the output
+ self.refs['out'].bind_cond(current)
+ # return the current, if needed
+ return current
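
The ``_Delta`` output defined at the top of this file differs from conductance-based outputs such as ``COBA`` in how the bound value reaches the post-synaptic group: ``COBA`` scales the bound conductance by the driving force, whereas ``_Delta`` hands the bound value through unchanged. The stand-alone sketch below illustrates the contrast; ``CobaLike`` is a simplified, illustrative stand-in rather than the library class.

.. code-block:: python

   class CobaLike:
       # conductance-based output: I = g * (E - V)
       def __init__(self, E):
           self.E = E
           self._g = None

       def bind_cond(self, g):
           self._g = g

       def __call__(self, V):
           return self._g * (self.E - V)   # scaled by the driving force

   class DeltaLike:
       # delta output: the bound value is delivered as-is
       def __init__(self):
           self._cond = None

       def bind_cond(self, cond):
           self._cond = cond

       def __call__(self, *args, **kwargs):
           return self._cond               # independent of the membrane potential

   out = DeltaLike()
   out.bind_cond(2.0)     # e.g. the result of ``comm(spikes)`` for this step
   print(out(-55.0))      # -> 2.0, regardless of the voltage argument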
diff --git a/brainpy/_src/dyn/projections/inputs.py b/brainpy/_src/dyn/projections/inputs.py
index f0001988b..dd1e1e3df 100644
--- a/brainpy/_src/dyn/projections/inputs.py
+++ b/brainpy/_src/dyn/projections/inputs.py
@@ -1,96 +1,167 @@
-from typing import Optional, Any
+import numbers
+from typing import Any
+from typing import Union, Optional
-from brainpy import math as bm
+from brainpy import check, math as bm
+from brainpy._src.context import share
from brainpy._src.dynsys import Dynamic
+from brainpy._src.dynsys import Projection
from brainpy._src.mixin import SupportAutoDelay
from brainpy.types import Shape
__all__ = [
- 'InputVar',
+ 'InputVar',
+ 'PoissonInput',
]
class InputVar(Dynamic, SupportAutoDelay):
- """Define an input variable.
+ """Define an input variable.
- Example::
+ Example::
+
+ import brainpy as bp
- import brainpy as bp
-
- class Exponential(bp.Projection):
- def __init__(self, pre, post, prob, g_max, tau, E=0.):
- super().__init__()
- self.proj = bp.dyn.ProjAlignPostMg2(
- pre=pre,
- delay=None,
- comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
- syn=bp.dyn.Expon.desc(post.num, tau=tau),
- out=bp.dyn.COBA.desc(E=E),
- post=post,
- )
-
-
- class EINet(bp.DynSysGroup):
- def __init__(self, num_exc, num_inh, method='exp_auto'):
- super(EINet, self).__init__()
-
- # neurons
- pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.), method=method)
- self.E = bp.dyn.LifRef(num_exc, **pars)
- self.I = bp.dyn.LifRef(num_inh, **pars)
-
- # synapses
- w_e = 0.6 # excitatory synaptic weight
- w_i = 6.7 # inhibitory synaptic weight
-
- # Neurons connect to each other randomly with a connection probability of 2%
- self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.)
- self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.)
- self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.)
- self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.)
-
- # define input variables given to E/I populations
- self.Ein = bp.dyn.InputVar(self.E.varshape)
- self.Iin = bp.dyn.InputVar(self.I.varshape)
- self.E.add_inp_fun('', self.Ein)
- self.I.add_inp_fun('', self.Iin)
-
-
- net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method
- runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)])
- runner.run(100.)
-
- # visualization
- bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'],
- title='Spikes of Excitatory Neurons', show=True)
- bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'],
- title='Spikes of Inhibitory Neurons', show=True)
-
-
- """
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- sharding: Optional[Any] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- method: str = 'exp_auto'
- ):
- super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method)
-
- self.reset_state(self.mode)
-
- def reset_state(self, batch_or_mode=None, **kwargs):
- self.input = self.init_variable(bm.zeros, batch_or_mode)
-
- def update(self, *args, **kwargs):
- return self.input.value
-
- def return_info(self):
- return self.input
-
- def clear_input(self, *args, **kwargs):
- self.reset_state(self.mode)
+ class Exponential(bp.Projection):
+ def __init__(self, pre, post, prob, g_max, tau, E=0.):
+ super().__init__()
+ self.proj = bp.dyn.ProjAlignPostMg2(
+ pre=pre,
+ delay=None,
+ comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
+ syn=bp.dyn.Expon.desc(post.num, tau=tau),
+ out=bp.dyn.COBA.desc(E=E),
+ post=post,
+ )
+
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self, num_exc, num_inh, method='exp_auto'):
+ super(EINet, self).__init__()
+
+ # neurons
+ pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.), method=method)
+ self.E = bp.dyn.LifRef(num_exc, **pars)
+ self.I = bp.dyn.LifRef(num_inh, **pars)
+
+ # synapses
+ w_e = 0.6 # excitatory synaptic weight
+ w_i = 6.7 # inhibitory synaptic weight
+
+ # Neurons connect to each other randomly with a connection probability of 2%
+ self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.)
+ self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.)
+ self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.)
+ self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.)
+
+ # define input variables given to E/I populations
+ self.Ein = bp.dyn.InputVar(self.E.varshape)
+ self.Iin = bp.dyn.InputVar(self.I.varshape)
+ self.E.add_inp_fun('', self.Ein)
+ self.I.add_inp_fun('', self.Iin)
+
+
+ net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method
+ runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)])
+ runner.run(100.)
+
+ # visualization
+ bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'],
+ title='Spikes of Excitatory Neurons', show=True)
+ bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'],
+ title='Spikes of Inhibitory Neurons', show=True)
+
+
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ sharding: Optional[Any] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ method: str = 'exp_auto'
+ ):
+ super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method)
+
+ self.reset_state(self.mode)
+
+ def reset_state(self, batch_or_mode=None, **kwargs):
+ self.input = self.init_variable(bm.zeros, batch_or_mode)
+
+ def update(self, *args, **kwargs):
+ return self.input.value
+
+ def return_info(self):
+ return self.input
+
+ def clear_input(self, *args, **kwargs):
+ self.reset_state(self.mode)
+
+
+class PoissonInput(Projection):
+ """Poisson Input to the given :py:class:`~.Variable`.
+
+ Adds independent Poisson input to a target variable. For large
+ numbers of inputs, this is much more efficient than creating a
+ `PoissonGroup`. The synaptic events are generated randomly during the
+ simulation and are not preloaded and stored in memory. All the inputs must
+ target the same variable, have the same frequency and same synaptic weight.
+ All neurons in the target variable receive independent realizations of
+ Poisson spike trains.
+
+ Args:
+ target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`.
+ num_input: The number of inputs.
+ freq: The frequency of each of the inputs. Must be a scalar.
+ weight: The synaptic weight. Must be a scalar.
+ name: The target name.
+ mode: The computing mode.
+ """
+
+ def __init__(
+ self,
+ target_var: bm.Variable,
+ num_input: int,
+ freq: Union[int, float],
+ weight: Union[int, float],
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ if not isinstance(target_var, bm.Variable):
+ raise TypeError(f'"target_var" must be an instance of Variable. '
+ f'But we got {type(target_var)}: {target_var}')
+ self.target_var = target_var
+ self.num_input = check.is_integer(num_input, min_bound=1)
+ self.freq = check.is_float(freq, min_bound=0., allow_int=True)
+ self.weight = check.is_float(weight, allow_int=True)
+
+ def reset_state(self, *args, **kwargs):
+ pass
+
+ def update(self):
+ p = self.freq * share['dt'] / 1e3
+ a = self.num_input * p
+ b = self.num_input * (1 - p)
+
+ if isinstance(share['dt'], numbers.Number): # dt is not traced
+ if (a > 5) and (b > 5):
+ inp = bm.random.normal(a, b * p, self.target_var.shape)
+ else:
+ inp = bm.random.binomial(self.num_input, p, self.target_var.shape)
+
+ else: # dt is traced
+ inp = bm.cond((a > 5) * (b > 5),
+ lambda: bm.random.normal(a, b * p, self.target_var.shape),
+ lambda: bm.random.binomial(self.num_input, p, self.target_var.shape))
+
+ # inp = bm.sharding.partition(inp, self.target_var.sharding)
+ self.target_var += inp * self.weight
+
+ def __repr__(self):
+ return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})'
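
The sampling logic in ``PoissonInput.update`` above reads as follows: each of the ``num_input`` hidden sources fires within one step of length ``dt`` (in ms) with probability ``p = freq * dt / 1e3``, so the event count per target element is ``Binomial(num_input, p)``; when both ``num_input * p`` and ``num_input * (1 - p)`` exceed 5, a normal draw is used instead, where ``b * p`` equals the binomial variance ``n * p * (1 - p)``. The plain-NumPy sketch below mirrors that decision with illustrative numbers; it is not part of the patch.

.. code-block:: python

   import numpy as np

   num_input, freq, weight, dt = 100, 20., 0.1, 0.1   # freq in Hz, dt in ms
   shape = (8,)                                       # shape of the target variable

   p = freq * dt / 1e3        # 0.002: firing probability per source per step
   a = num_input * p          # 0.2
   b = num_input * (1 - p)    # 99.8

   rng = np.random.default_rng(0)
   if a > 5 and b > 5:
       # normal branch, mirroring ``bm.random.normal(a, b * p, shape)`` above
       events = rng.normal(a, b * p, size=shape)
   else:
       # here a <= 5, so the exact binomial draw is taken
       events = rng.binomial(num_input, p, size=shape)

   increment = events * weight   # what ``update()`` adds to the target variable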
diff --git a/brainpy/_src/dyn/projections/others.py b/brainpy/_src/dyn/projections/others.py
deleted file mode 100644
index 72a77298f..000000000
--- a/brainpy/_src/dyn/projections/others.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import numbers
-import warnings
-from typing import Union, Optional
-
-from brainpy import check, math as bm
-from brainpy._src.context import share
-from brainpy._src.dynsys import Projection
-
-__all__ = [
- 'PoissonInput',
-]
-
-
-class PoissonInput(Projection):
- """Poisson Input to the given :py:class:`~.Variable`.
-
- Adds independent Poisson input to a target variable. For large
- numbers of inputs, this is much more efficient than creating a
- `PoissonGroup`. The synaptic events are generated randomly during the
- simulation and are not preloaded and stored in memory. All the inputs must
- target the same variable, have the same frequency and same synaptic weight.
- All neurons in the target variable receive independent realizations of
- Poisson spike trains.
-
- Args:
- target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`.
- num_input: The number of inputs.
- freq: The frequency of each of the inputs. Must be a scalar.
- weight: The synaptic weight. Must be a scalar.
- name: The target name.
- mode: The computing mode.
- """
-
- def __init__(
- self,
- target_var: bm.Variable,
- num_input: int,
- freq: Union[int, float],
- weight: Union[int, float],
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- seed=None
- ):
- super().__init__(name=name, mode=mode)
-
- if seed is not None:
- warnings.warn('')
-
- if not isinstance(target_var, bm.Variable):
- raise TypeError(f'"target_var" must be an instance of Variable. '
- f'But we got {type(target_var)}: {target_var}')
- self.target_var = target_var
- self.num_input = check.is_integer(num_input, min_bound=1)
- self.freq = check.is_float(freq, min_bound=0., allow_int=True)
- self.weight = check.is_float(weight, allow_int=True)
-
- def reset_state(self, *args, **kwargs):
- pass
-
- def update(self):
- p = self.freq * share['dt'] / 1e3
- a = self.num_input * p
- b = self.num_input * (1 - p)
-
- if isinstance(share['dt'], numbers.Number): # dt is not traced
- if (a > 5) and (b > 5):
- inp = bm.random.normal(a, b * p, self.target_var.shape)
- else:
- inp = bm.random.binomial(self.num_input, p, self.target_var.shape)
-
- else: # dt is traced
- inp = bm.cond((a > 5) * (b > 5),
- lambda: bm.random.normal(a, b * p, self.target_var.shape),
- lambda: bm.random.binomial(self.num_input, p, self.target_var.shape),
- ())
-
- # inp = bm.sharding.partition(inp, self.target_var.sharding)
- self.target_var += inp * self.weight
-
- def __repr__(self):
- return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})'
diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py
index 3ee6f4fef..439b6eb6c 100644
--- a/brainpy/_src/dyn/projections/plasticity.py
+++ b/brainpy/_src/dyn/projections/plasticity.py
@@ -7,8 +7,9 @@
from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay,
BindCondData, AlignPost, SupportSTDP)
from brainpy.types import ArrayType
-from .aligns import (_get_return, align_post_add_bef_update,
- align_pre2_add_bef_update, add_inp_fun)
+from .align_post import (align_post_add_bef_update, )
+from .align_pre import (align_pre2_add_bef_update, )
+from .utils import (_get_return, )
__all__ = [
'STDP_Song2000',
@@ -49,8 +50,8 @@ class STDP_Song2000(Projection):
\begin{aligned}
\frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
- \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\
- \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\
+ \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\
+ \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\
\end{aligned}
where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
@@ -64,8 +65,8 @@ class STDP_Song2000(Projection):
class STDPNet(bp.DynamicalSystem):
def __init__(self, num_pre, num_post):
super().__init__()
- self.pre = bp.dyn.LifRef(num_pre, name='neu1')
- self.post = bp.dyn.LifRef(num_post, name='neu2')
+ self.pre = bp.dyn.LifRef(num_pre)
+ self.post = bp.dyn.LifRef(num_post)
self.syn = bp.dyn.STDP_Song2000(
pre=self.pre,
delay=1.,
@@ -165,7 +166,7 @@ def __init__(
else:
syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre')
out_cls = out()
- add_inp_fun(out_label, self.name, out_cls, post)
+ post.add_inp_fun(self.name, out_cls, label=out_label)
# references
self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
@@ -188,6 +189,12 @@ def __init__(
self.A1 = A1
self.A2 = A2
+ pre = property(lambda self: self.refs['pre'])
+ post = property(lambda self: self.refs['post'])
+ syn = property(lambda self: self.refs['syn'])
+ delay = property(lambda self: self.refs['delay'])
+ out = property(lambda self: self.refs['out'])
+
def update(self):
# pre-synaptic spikes
pre_spike = self.refs['delay'].at(self.name) # spike
@@ -219,3 +226,193 @@ def update(self):
return current
+# class PairedSTDP(Projection):
+# r"""Paired spike-time-dependent plasticity model.
+#
+# This model filters the synaptic currents according to the variables: :math:`w`.
+#
+# .. math::
+#
+# I_{syn}^+(t) = I_{syn}^-(t) * w
+#
+# where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
+# and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse,
+# the conductance of the synapse will increase w.
+#
+# The dynamics of :math:`w` is governed by the following equation:
+#
+# .. math::
+#
+# \begin{aligned}
+# \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
+# \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\
+# \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\
+# \end{aligned}
+#
+# where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
+# of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike.
+#
+# Here is an example of the usage of this class::
+#
+# import brainpy as bp
+# import brainpy.math as bm
+#
+# class STDPNet(bp.DynamicalSystem):
+# def __init__(self, num_pre, num_post):
+# super().__init__()
+# self.pre = bp.dyn.LifRef(num_pre)
+# self.post = bp.dyn.LifRef(num_post)
+# self.syn = bp.dyn.STDP_Song2000(
+# pre=self.pre,
+# delay=1.,
+# comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
+# weight=bp.init.Uniform(max_val=0.1)),
+# syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
+# out=bp.dyn.COBA.desc(E=0.),
+# post=self.post,
+# tau_s=16.8,
+# tau_t=33.7,
+# A1=0.96,
+# A2=0.53,
+# )
+#
+# def update(self, I_pre, I_post):
+# self.syn()
+# self.pre(I_pre)
+# self.post(I_post)
+# conductance = self.syn.refs['syn'].g
+# Apre = self.syn.refs['pre_trace'].g
+# Apost = self.syn.refs['post_trace'].g
+# current = self.post.sum_inputs(self.post.V)
+# return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
+#
+# duration = 300.
+# I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
+# [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
+# I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
+# [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])
+#
+# net = STDPNet(1, 1)
+# def run(i, I_pre, I_post):
+# pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
+# return pre_spike, post_spike, g, Apre, Apost, current, W
+#
+# indices = bm.arange(0, duration, bm.dt)
+# pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post])
+#
+# Args:
+# tau_s: float. The time constant of :math:`A_{pre}`.
+# tau_t: float. The time constant of :math:`A_{post}`.
+# A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value.
+# A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value.
+# W_max: float. The maximum weight.
+# W_min: float. The minimum weight.
+# pre: DynamicalSystem. The pre-synaptic neuron group.
+# delay: int, float. The pre spike delay length. (ms)
+# syn: DynamicalSystem. The synapse model.
+# comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers.
+# out: DynamicalSystem. The synaptic current output models.
+# post: DynamicalSystem. The post-synaptic neuron group.
+# out_label: str. The output label.
+# name: str. The model name.
+# """
+#
+# def __init__(
+# self,
+# pre: JointType[DynamicalSystem, SupportAutoDelay],
+# delay: Union[None, int, float],
+# syn: ParamDescriber[DynamicalSystem],
+# comm: JointType[DynamicalSystem, SupportSTDP],
+# out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
+# post: DynamicalSystem,
+# # synapse parameters
+# tau_s: float = 16.8,
+# tau_t: float = 33.7,
+# lambda_: float = 0.96,
+# alpha: float = 0.53,
+# mu: float = 0.53,
+# W_max: Optional[float] = None,
+# W_min: Optional[float] = None,
+# # others
+# out_label: Optional[str] = None,
+# name: Optional[str] = None,
+# mode: Optional[bm.Mode] = None,
+# ):
+# super().__init__(name=name, mode=mode)
+#
+# # synaptic models
+# check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
+# check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
+# check.is_instance(syn, ParamDescriber[DynamicalSystem])
+# check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
+# check.is_instance(post, DynamicalSystem)
+# self.pre_num = pre.num
+# self.post_num = post.num
+# self.comm = comm
+# self._is_align_post = issubclass(syn.cls, AlignPost)
+#
+# # delay initialization
+# delay_cls = register_delay_by_return(pre)
+# delay_cls.register_entry(self.name, delay)
+#
+# # synapse and output initialization
+# if self._is_align_post:
+# syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post,
+# proj_name=self.name)
+# else:
+# syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre')
+# out_cls = out()
+# add_inp_fun(out_label, self.name, out_cls, post)
+#
+# # references
+# self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
+# self.refs['delay'] = delay_cls
+# self.refs['syn'] = syn_cls # invisible to ``self.node()``
+# self.refs['out'] = out_cls # invisible to ``self.node()``
+# self.refs['comm'] = comm
+#
+# # tracing pre-synaptic spikes using Exponential model
+# self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s))
+#
+# # tracing post-synaptic spikes using Exponential model
+# self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t))
+#
+# # synapse parameters
+# self.W_max = W_max
+# self.W_min = W_min
+# self.tau_s = tau_s
+# self.tau_t = tau_t
+# self.A1 = A1
+# self.A2 = A2
+#
+# def update(self):
+# # pre-synaptic spikes
+# pre_spike = self.refs['delay'].at(self.name) # spike
+# # pre-synaptic variables
+# if self._is_align_post:
+# # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance
+# x = pre_spike
+# else:
+# # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance
+# x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable
+#
+# # post spikes
+# if not hasattr(self.refs['post'], 'spike'):
+# raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.')
+# post_spike = self.refs['post'].spike
+#
+# # weight updates
+# Apost = self.refs['post_trace'].g
+# self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min, w_max=self.W_max)
+# Apre = self.refs['pre_trace'].g
+# self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min, w_max=self.W_max)
+#
+# # synaptic currents
+# current = self.comm(x)
+# if self._is_align_post:
+# self.refs['syn'].add_current(current) # synapse post current
+# else:
+# self.refs['out'].bind_cond(current) # align pre
+# return current
+
+
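The ``pre``/``post``/``syn``/``delay``/``out`` properties added to ``STDP_Song2000`` above expose the entries of ``self.refs`` directly, so user code no longer needs ``refs['...']`` for the common handles (the spike traces remain reachable only through ``refs``). A minimal sketch, reusing the components from the docstring example:

.. code-block:: python

   import brainpy as bp

   pre = bp.dyn.LifRef(4)
   post = bp.dyn.LifRef(3)
   stdp = bp.dyn.STDP_Song2000(
     pre=pre,
     delay=1.,
     comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1., pre=pre.num, post=post.num),
                                weight=bp.init.Uniform(max_val=0.1)),
     syn=bp.dyn.Expon.desc(post.varshape, tau=5.),
     out=bp.dyn.COBA.desc(E=0.),
     post=post,
   )

   assert stdp.pre is pre and stdp.post is post  # new read-only properties
   g = stdp.syn.g                                # same object as stdp.refs['syn'].g
   A_pre = stdp.refs['pre_trace'].g              # traces still live under ``refs``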
diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py
index a4173c7ba..b8884f327 100644
--- a/brainpy/_src/dyn/projections/tests/test_STDP.py
+++ b/brainpy/_src/dyn/projections/tests/test_STDP.py
@@ -86,7 +86,7 @@ def update(self, I_pre, I_post):
conductance = self.syn.refs['syn'].g
Apre = self.syn.refs['pre_trace'].g
Apost = self.syn.refs['post_trace'].g
- current = self.post.sum_inputs(self.post.V)
+ current = self.post.sum_current_inputs(self.post.V)
if comm_method == 'dense':
w = self.syn.comm.W.flatten()
else:
diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py
index 32b072e5a..90500a26f 100644
--- a/brainpy/_src/dyn/projections/tests/test_aligns.py
+++ b/brainpy/_src/dyn/projections/tests/test_aligns.py
@@ -19,7 +19,7 @@ def __init__(self, scale=1., inp=20., delay=None):
prob = 80 / (4000 * scale)
- self.E2I = bp.dyn.ProjAlignPreMg1(
+ self.E2I = bp.dyn.FullProjAlignPreSDMg(
pre=self.E,
syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.),
delay=delay,
@@ -27,7 +27,7 @@ def __init__(self, scale=1., inp=20., delay=None):
out=bp.dyn.COBA(E=0.),
post=self.I,
)
- self.E2E = bp.dyn.ProjAlignPreMg1(
+ self.E2E = bp.dyn.FullProjAlignPreSDMg(
pre=self.E,
syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.),
delay=delay,
@@ -35,7 +35,7 @@ def __init__(self, scale=1., inp=20., delay=None):
out=bp.dyn.COBA(E=0.),
post=self.E,
)
- self.I2E = bp.dyn.ProjAlignPreMg1(
+ self.I2E = bp.dyn.FullProjAlignPreSDMg(
pre=self.I,
syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.),
delay=delay,
@@ -43,7 +43,7 @@ def __init__(self, scale=1., inp=20., delay=None):
out=bp.dyn.COBA(E=-80.),
post=self.E,
)
- self.I2I = bp.dyn.ProjAlignPreMg1(
+ self.I2I = bp.dyn.FullProjAlignPreSDMg(
pre=self.I,
syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.),
delay=delay,
@@ -90,7 +90,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None):
prob = 80 / (4000 * scale)
- self.E2E = bp.dyn.ProjAlignPostMg2(
+ self.E2E = bp.dyn.FullProjAlignPostMg(
pre=self.E,
delay=delay,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6),
@@ -98,7 +98,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None):
out=bp.dyn.COBA.desc(E=0.),
post=self.E,
)
- self.E2I = bp.dyn.ProjAlignPostMg2(
+ self.E2I = bp.dyn.FullProjAlignPostMg(
pre=self.E,
delay=delay,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6),
@@ -106,7 +106,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None):
out=bp.dyn.COBA.desc(E=0.),
post=self.I,
)
- self.I2E = bp.dyn.ProjAlignPostMg2(
+ self.I2E = bp.dyn.FullProjAlignPostMg(
pre=self.I,
delay=delay,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7),
@@ -114,7 +114,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None):
out=bp.dyn.COBA.desc(E=-80.),
post=self.E,
)
- self.I2I = bp.dyn.ProjAlignPostMg2(
+ self.I2I = bp.dyn.FullProjAlignPostMg(
pre=self.I,
delay=delay,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7),
@@ -163,14 +163,14 @@ def __init__(self, scale=1.):
self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
- self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6),
- syn=bp.dyn.Expon(size=num, tau=5.),
- out=bp.dyn.COBA(E=0.),
- post=self.N)
- self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7),
- syn=bp.dyn.Expon(size=num, tau=10.),
- out=bp.dyn.COBA(E=-80.),
- post=self.N)
+ self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6),
+ syn=bp.dyn.Expon(size=num, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.N)
+ self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7),
+ syn=bp.dyn.Expon(size=num, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.N)
def update(self, input):
spk = self.delay.at('I')
@@ -198,30 +198,30 @@ def __init__(self, scale, delay=None):
V_initializer=bp.init.Normal(-55., 2.))
self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPost2(pre=self.E,
- delay=delay,
- comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6),
- syn=bp.dyn.Expon(size=ne, tau=5.),
- out=bp.dyn.COBA(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPost2(pre=self.E,
- delay=delay,
- comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6),
- syn=bp.dyn.Expon(size=ni, tau=5.),
- out=bp.dyn.COBA(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPost2(pre=self.I,
- delay=delay,
- comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7),
- syn=bp.dyn.Expon(size=ne, tau=10.),
- out=bp.dyn.COBA(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPost2(pre=self.I,
- delay=delay,
- comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7),
- syn=bp.dyn.Expon(size=ni, tau=10.),
- out=bp.dyn.COBA(E=-80.),
- post=self.I)
+ self.E2E = bp.dyn.FullProjAlignPost(pre=self.E,
+ delay=delay,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6),
+ syn=bp.dyn.Expon(size=ne, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPost(pre=self.E,
+ delay=delay,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6),
+ syn=bp.dyn.Expon(size=ni, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPost(pre=self.I,
+ delay=delay,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7),
+ syn=bp.dyn.Expon(size=ne, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPost(pre=self.I,
+ delay=delay,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7),
+ syn=bp.dyn.Expon(size=ni, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
def update(self, inp):
self.E2E()
@@ -292,30 +292,30 @@ def __init__(self, scale=1., delay=None):
V_initializer=bp.init.Normal(-55., 2.))
self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- delay=delay,
- comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- delay=delay,
- comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- delay=delay,
- comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- delay=delay,
- comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.I)
+ self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ delay=delay,
+ comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ delay=delay,
+ comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ delay=delay,
+ comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ delay=delay,
+ comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
def update(self, inp):
self.E2E()
@@ -350,30 +350,30 @@ def __init__(self, scale=1., delay=None):
V_initializer=bp.init.Normal(-55., 2.))
self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
- self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E,
- delay=delay,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.E)
- self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E,
- delay=delay,
- syn=bp.dyn.Expon.desc(size=ne, tau=5.),
- comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6),
- out=bp.dyn.COBA(E=0.),
- post=self.I)
- self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I,
- delay=delay,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.E)
- self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I,
- delay=delay,
- syn=bp.dyn.Expon.desc(size=ni, tau=10.),
- comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7),
- out=bp.dyn.COBA(E=-80.),
- post=self.I)
+ self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E,
+ delay=delay,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E,
+ delay=delay,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I,
+ delay=delay,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I,
+ delay=delay,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
def update(self, inp):
self.E2E()
diff --git a/brainpy/_src/dyn/projections/tests/test_delta.py b/brainpy/_src/dyn/projections/tests/test_delta.py
new file mode 100644
index 000000000..f4d21b643
--- /dev/null
+++ b/brainpy/_src/dyn/projections/tests/test_delta.py
@@ -0,0 +1,51 @@
+import matplotlib.pyplot as plt
+
+import brainpy as bp
+import brainpy.math as bm
+
+
+class NetForHalfProj(bp.DynamicalSystem):
+ def __init__(self):
+ super().__init__()
+
+ self.pre = bp.dyn.PoissonGroup(10, 100.)
+ self.post = bp.dyn.LifRef(1)
+ self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post)
+
+ def update(self):
+ self.syn(self.pre())
+ self.post()
+ return self.post.V.value
+
+
+def test1():
+ net = NetForHalfProj()
+ indices = bm.arange(1000).to_numpy()
+ vs = bm.for_loop(net.step_run, indices, progress_bar=True)
+ bp.visualize.line_plot(indices, vs, show=False)
+ plt.close('all')
+
+
+class NetForFullProj(bp.DynamicalSystem):
+ def __init__(self):
+ super().__init__()
+
+ self.pre = bp.dyn.PoissonGroup(10, 100.)
+ self.post = bp.dyn.LifRef(1)
+ self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post)
+
+ def update(self):
+ self.syn()
+ self.pre()
+ self.post()
+ return self.post.V.value
+
+
+def test2():
+ net = NetForFullProj()
+ indices = bm.arange(1000).to_numpy()
+ vs = bm.for_loop(net.step_run, indices, progress_bar=True)
+ bp.visualize.line_plot(indices, vs, show=False)
+ plt.close('all')
+
+
diff --git a/brainpy/_src/dyn/projections/utils.py b/brainpy/_src/dyn/projections/utils.py
new file mode 100644
index 000000000..44a2273a4
--- /dev/null
+++ b/brainpy/_src/dyn/projections/utils.py
@@ -0,0 +1,12 @@
+from brainpy import math as bm
+from brainpy._src.mixin import ReturnInfo
+
+
+def _get_return(return_info):
+ if isinstance(return_info, bm.Variable):
+ return return_info.value
+ elif isinstance(return_info, ReturnInfo):
+ return return_info.get_data()
+ else:
+ raise NotImplementedError
+
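``_get_return`` is the small dispatcher the projections use to turn a synapse's ``return_info()`` into an array: a ``Variable`` yields its raw value, while a ``ReturnInfo`` is materialized through ``get_data()``. A sketch of both branches (this is a private helper under ``brainpy._src``, so the import path is an internal detail and may move):

.. code-block:: python

   import brainpy as bp
   from brainpy._src.dyn.projections.utils import _get_return

   syn_v1 = bp.dyn.Expon(8, tau=5.)        # ``return_info()`` returns a Variable
   x = _get_return(syn_v1.return_info())   # -> the variable's raw value

   syn_v2 = bp.dyn.DualExponV2(8)          # ``return_info()`` returns a ReturnInfo
   y = _get_return(syn_v2.return_info())   # -> data produced by ``ReturnInfo.get_data()``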
diff --git a/brainpy/_src/dyn/projections/vanilla.py b/brainpy/_src/dyn/projections/vanilla.py
new file mode 100644
index 000000000..15773d231
--- /dev/null
+++ b/brainpy/_src/dyn/projections/vanilla.py
@@ -0,0 +1,83 @@
+from typing import Optional
+
+from brainpy import math as bm, check
+from brainpy._src.dynsys import DynamicalSystem, Projection
+from brainpy._src.mixin import (JointType, BindCondData)
+
+__all__ = [
+ 'VanillaProj',
+]
+
+
+class VanillaProj(Projection):
+ """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group.
+
+ **Code Examples**
+
+ To simulate an E/I balanced network model:
+
+ .. code-block::
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
+ self.syn1 = bp.dyn.Expon(size=3200, tau=5.)
+ self.syn2 = bp.dyn.Expon(size=800, tau=10.)
+ self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.N)
+ self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.N)
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(self.syn1(spk[:3200]))
+ self.I(self.syn2(spk[3200:]))
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ comm: The synaptic communication.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ comm: DynamicalSystem,
+ out: JointType[DynamicalSystem, BindCondData],
+ post: DynamicalSystem,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, DynamicalSystem)
+ self.comm = comm
+
+ # output initialization
+ post.add_inp_fun(self.name, out)
+
+ # references
+ self.refs = dict(post=post, out=out) # invisible to ``self.nodes()``
+ self.refs['comm'] = comm # unify the access
+
+ def update(self, x):
+ current = self.comm(x)
+ self.refs['out'].bind_cond(current)
+ return current
diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py
index 4a6b9ddb6..cdc1912d7 100644
--- a/brainpy/_src/dyn/synapses/abstract_models.py
+++ b/brainpy/_src/dyn/synapses/abstract_models.py
@@ -2,7 +2,8 @@
from brainpy import math as bm
from brainpy._src.context import share
-from brainpy._src.dyn._docs import pneu_doc
+from brainpy._src.initialize import parameter
+from brainpy._src.dyn import _docs
from brainpy._src.dyn.base import SynDyn
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
@@ -10,7 +11,6 @@
from brainpy.types import ArrayType
__all__ = [
- 'Delta',
'Expon',
'DualExpon',
'DualExponV2',
@@ -21,94 +21,10 @@
]
-class Delta(SynDyn, AlignPost):
- r"""Delta synapse model.
-
- **Model Descriptions**
-
- The single exponential decay synapse model assumes the release of neurotransmitter,
- its diffusion across the cleft, the receptor binding, and channel opening all happen
- very quickly, so that the channels instantaneously jump from the closed to the open state.
- Therefore, its expression is given by
-
- .. math::
-
- g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau}
-
- where :math:`\tau_{delay}` is the time constant of the synaptic state decay,
- :math:`t_0` is the time of the pre-synaptic spike,
- :math:`g_{\mathrm{max}}` is the maximal conductance.
-
- Accordingly, the differential form of the exponential synapse is given by
-
- .. math::
-
- \begin{aligned}
- & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}).
- \end{aligned}
-
- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
- "The Synapse." Principles of Computational Modelling in Neuroscience.
- Cambridge: Cambridge UP, 2011. 172-95. Print.
-
- """
-
- def __init__(
- self,
- size: Union[int, Sequence[int]],
- keep_size: bool = False,
- sharding: Optional[Sequence[str]] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name,
- mode=mode,
- size=size,
- keep_size=keep_size,
- sharding=sharding)
-
- self.reset_state(self.mode)
-
- def reset_state(self, batch_or_mode=None, **kwargs):
- self.g = self.init_variable(bm.zeros, batch_or_mode)
-
- def update(self, x=None):
- if x is not None:
- self.g.value += x
- return self.g.value
-
- def add_current(self, x):
- self.g.value += x
-
- def return_info(self):
- return self.g
-
-
class Expon(SynDyn, AlignPost):
r"""Exponential decay synapse model.
- **Model Descriptions**
-
- The single exponential decay synapse model assumes the release of neurotransmitter,
- its diffusion across the cleft, the receptor binding, and channel opening all happen
- very quickly, so that the channels instantaneously jump from the closed to the open state.
- Therefore, its expression is given by
-
- .. math::
-
- g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau}
-
- where :math:`\tau_{delay}` is the time constant of the synaptic state decay,
- :math:`t_0` is the time of the pre-synaptic spike,
- :math:`g_{\mathrm{max}}` is the maximal conductance.
-
- Accordingly, the differential form of the exponential synapse is given by
-
- .. math::
-
- \begin{aligned}
- & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}).
- \end{aligned}
+ %s
This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example:
@@ -170,11 +86,6 @@ def __init__(self, pre, post, delay, prob, g_max, tau, E):
)
-
- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
- "The Synapse." Principles of Computational Modelling in Neuroscience.
- Cambridge: Cambridge UP, 2011. 172-95. Print.
-
Args:
tau: float. The time constant of decay. [ms]
%s
@@ -226,36 +137,21 @@ def return_info(self):
return self.g
-Expon.__doc__ = Expon.__doc__ % (pneu_doc,)
+Expon.__doc__ = Expon.__doc__ % (_docs.exp_syn_doc, _docs.pneu_doc,)
-class DualExpon(SynDyn):
- r"""Dual exponential synapse model.
+def _format_dual_exp_A(self, A):
+ A = parameter(A, sizes=self.varshape, allow_none=True, sharding=self.sharding)
+ if A is None:
+ A = (self.tau_decay / (self.tau_decay - self.tau_rise) *
+ bm.float_power(self.tau_rise / self.tau_decay, self.tau_rise / (self.tau_rise - self.tau_decay)))
+ return A
- **Model Descriptions**
-
- The dual exponential synapse model [1]_, also named as *difference of two exponentials* model,
- is given by:
-
- .. math::
- g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{
- \tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right)
- -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right)
-
- where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2`
- is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic
- spike, :math:`g_{\mathrm{max}}` is the maximal conductance.
-
- However, in practice, this formula is hard to implement. The equivalent solution is
- two coupled linear differential equations [2]_:
-
- .. math::
+class DualExpon(SynDyn):
+ r"""Dual exponential synapse model.
- \begin{aligned}
- &\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\
- &\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right),
- \end{aligned}
+ %s
This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example:
@@ -267,11 +163,9 @@ class DualExpon(SynDyn):
import matplotlib.pyplot as plt
-
class DualExpSparseCOBA(bp.Projection):
def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E):
super().__init__()
-
self.proj = bp.dyn.ProjAlignPreMg2(
pre=pre,
delay=delay,
@@ -281,7 +175,6 @@ def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E):
post=post,
)
-
class SimpleNet(bp.DynSysGroup):
def __init__(self, syn_cls, E=0.):
super().__init__()
@@ -317,15 +210,16 @@ def update(self):
plt.title('Post V')
plt.show()
- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
- "The Synapse." Principles of Computational Modelling in Neuroscience.
- Cambridge: Cambridge UP, 2011. 172-95. Print.
- .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational
- Modeling Methods for Neuroscientists.
+ See Also:
+ DualExponV2
+
+ .. note::
+
+    The implementation of this model can only be used in ``AlignPre`` projections.
+    On the contrary, for ``AlignPost`` projections, please use ``DualExponV2``.

Args:
- tau_decay: float, ArrayArray, Callable. The time constant of the synaptic decay phase. [ms]
- tau_rise: float, ArrayArray, Callable. The time constant of the synaptic rise phase. [ms]
+ %s
%s
"""
@@ -341,6 +235,7 @@ def __init__(
# synapse parameters
tau_decay: Union[float, ArrayType, Callable] = 10.0,
tau_rise: Union[float, ArrayType, Callable] = 1.,
+ A: Optional[Union[float, ArrayType, Callable]] = None,
):
super().__init__(name=name,
mode=mode,
@@ -351,6 +246,8 @@ def __init__(
# parameters
self.tau_rise = self.init_param(tau_rise)
self.tau_decay = self.init_param(tau_decay)
+ A = _format_dual_exp_A(self, A)
+ self.a = (self.tau_decay - self.tau_rise) / self.tau_rise / self.tau_decay * A
# integrator
self.integral = odeint(JointEq(self.dg, self.dh), method=method)
@@ -368,33 +265,28 @@ def dg(self, g, t, h):
return -g / self.tau_decay + h
def update(self, x):
+ # x: the pre-synaptic spikes
+
# update synaptic variables
self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
- self.h += x
+ self.h += self.a * x
return self.g.value
def return_info(self):
return self.g
-DualExpon.__doc__ = DualExpon.__doc__ % (pneu_doc,)
+DualExpon.__doc__ = DualExpon.__doc__ % (_docs.dual_exp_syn_doc, _docs.pneu_doc, _docs.dual_exp_args)
class DualExponV2(SynDyn, AlignPost):
r"""Dual exponential synapse model.
- The dual exponential synapse model [1]_, also named as *difference of two exponentials* model,
- is given by:
-
- .. math::
+ %s
- g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{
- \tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right)
- -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right)
+ .. note::
- where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2`
- is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic
- spike, :math:`g_{\mathrm{max}}` is the maximal conductance.
+    Different from ``DualExpon``, this model can be used with both ``AlignPre`` and ``AlignPost`` projections.
This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example:
@@ -438,9 +330,6 @@ def update(self):
current = self.post.sum_inputs(self.post.V)
return conductance, current, self.post.V
-
-
-
indices = np.arange(1000) # 100 ms, dt= 0.1 ms
net = SimpleNet(DualExponV2SparseCOBAPost, E=0.)
conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True)
@@ -457,7 +346,6 @@ def update(self):
plt.title('Post V')
plt.show()
-
Moreover, it can also be used with interface ``ProjAlignPostMg2``:
.. code-block:: python
@@ -475,17 +363,11 @@ def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E):
post=post,
)
-
-
- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
- "The Synapse." Principles of Computational Modelling in Neuroscience.
- Cambridge: Cambridge UP, 2011. 172-95. Print.
- .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational
- Modeling Methods for Neuroscientists.
+ See Also:
+ DualExpon
Args:
- tau_decay: float, ArrayArray, Callable. The time constant of the synaptic decay phase. [ms]
- tau_rise: float, ArrayArray, Callable. The time constant of the synaptic rise phase. [ms]
+ %s
%s
"""
@@ -501,6 +383,7 @@ def __init__(
# synapse parameters
tau_decay: Union[float, ArrayType, Callable] = 10.0,
tau_rise: Union[float, ArrayType, Callable] = 1.,
+ A: Optional[Union[float, ArrayType, Callable]] = None,
):
super().__init__(name=name,
mode=mode,
@@ -511,7 +394,7 @@ def __init__(
# parameters
self.tau_rise = self.init_param(tau_rise)
self.tau_decay = self.init_param(tau_decay)
- self.coeff = self.tau_rise * self.tau_decay / (self.tau_decay - self.tau_rise)
+ self.a = _format_dual_exp_A(self, A)
# integrator
self.integral = odeint(lambda g, t, tau: -g / tau, method=method)
@@ -527,7 +410,7 @@ def update(self, x=None):
self.g_decay.value = self.integral(self.g_decay.value, share['t'], self.tau_decay, share['dt'])
if x is not None:
self.add_current(x)
- return self.coeff * (self.g_decay - self.g_rise)
+ return self.a * (self.g_decay - self.g_rise)
def add_current(self, inp):
self.g_rise += inp
@@ -535,32 +418,16 @@ def add_current(self, inp):
def return_info(self):
return ReturnInfo(self.varshape, self.sharding, self.mode,
- lambda shape: self.coeff * (self.g_decay - self.g_rise))
+ lambda shape: self.a * (self.g_decay - self.g_rise))
-DualExponV2.__doc__ = DualExponV2.__doc__ % (pneu_doc,)
+DualExponV2.__doc__ = DualExponV2.__doc__ % (_docs.dual_exp_syn_doc, _docs.pneu_doc, _docs.dual_exp_args,)
-class Alpha(DualExpon):
+class Alpha(SynDyn):
r"""Alpha synapse model.
- **Model Descriptions**
-
- The analytical expression of alpha synapse is given by:
-
- .. math::
-
- g_{syn}(t)= g_{max} \frac{t-t_{s}}{\tau} \exp \left(-\frac{t-t_{s}}{\tau}\right).
-
- While, this equation is hard to implement. So, let's try to convert it into the
- differential forms:
-
- .. math::
-
- \begin{aligned}
- &\frac{d g}{d t}=-\frac{g}{\tau}+h \\
- &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
- \end{aligned}
+ %s
This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example:
@@ -623,17 +490,9 @@ def update(self):
plt.show()
-
-
-
- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
- "The Synapse." Principles of Computational Modelling in Neuroscience.
- Cambridge: Cambridge UP, 2011. 172-95. Print.
-
Args:
%s
tau_decay: float, ArrayType, Callable. The time constant [ms] of the synaptic decay phase.
- The name of this synaptic projection.
"""
def __init__(
@@ -649,9 +508,6 @@ def __init__(
tau_decay: Union[float, ArrayType, Callable] = 10.0,
):
super().__init__(
- tau_decay=tau_decay,
- tau_rise=tau_decay,
- method=method,
name=name,
mode=mode,
size=size,
@@ -659,8 +515,35 @@ def __init__(
sharding=sharding
)
+ # parameters
+ self.tau_decay = self.init_param(tau_decay)
-Alpha.__doc__ = Alpha.__doc__ % (pneu_doc,)
+ # integrator
+ self.integral = odeint(JointEq(self.dg, self.dh), method=method)
+
+ self.reset_state(self.mode)
+
+ def reset_state(self, batch_or_mode=None, **kwargs):
+ self.h = self.init_variable(bm.zeros, batch_or_mode)
+ self.g = self.init_variable(bm.zeros, batch_or_mode)
+
+ def dh(self, h, t):
+ return -h / self.tau_decay
+
+ def dg(self, g, t, h):
+ return -g / self.tau_decay + h / self.tau_decay
+
+ def update(self, x):
+ # update synaptic variables
+ self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
+ self.h += x
+ return self.g.value
+
+ def return_info(self):
+ return self.g
+
+
+Alpha.__doc__ = Alpha.__doc__ % (_docs.alpha_syn_doc, _docs.pneu_doc,)
class NMDA(SynDyn):
@@ -845,30 +728,13 @@ def return_info(self):
return self.g
-NMDA.__doc__ = NMDA.__doc__ % (pneu_doc,)
+NMDA.__doc__ = NMDA.__doc__ % (_docs.pneu_doc,)
class STD(SynDyn):
r"""Synaptic output with short-term depression.
- This model filters the synaptic current by the following equation:
-
- .. math::
-
- I_{syn}^+(t) = I_{syn}^-(t) * x
-
- where :math:`x` is the normalized variable between 0 and 1, and
- :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
- and after STD filtering.
-
- Moreover, :math:`x` is updated according to the dynamics of:
-
- .. math::
-
- \frac{dx}{dt} = \frac{1-x}{\tau} - U * x * \delta(t-t_{spike})
-
- where :math:`U` is the fraction of resources used per action potential,
- :math:`\tau` is the time constant of recovery of the synaptic vesicles.
+ %s
Args:
tau: float, ArrayType, Callable. The time constant of recovery of the synaptic vesicles.
@@ -924,36 +790,13 @@ def return_info(self):
return self.x
-STD.__doc__ = STD.__doc__ % (pneu_doc,)
+STD.__doc__ = STD.__doc__ % (_docs.std_doc, _docs.pneu_doc,)
class STP(SynDyn):
r"""Synaptic output with short-term plasticity.
- This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`.
-
- .. math::
-
- I_{syn}^+(t) = I_{syn}^-(t) * x * u
-
- where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
- and after STP filtering, :math:`x` denotes the fraction of resources that remain available
- after neurotransmitter depletion, and :math:`u` represents the fraction of available
- resources ready for use (release probability).
-
- The dynamics of :math:`u` and :math:`x` are governed by
-
- .. math::
-
- \begin{aligned}
- \frac{du}{dt} & = & -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}), \\
- \frac{dx}{dt} & = & \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\
- \tag{1}\end{aligned}
-
- where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment
- of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding
- variables just before the arrival of the spike, and :math:`u^+`
- refers to the moment just after the spike.
+ %s
Args:
tau_f: float, ArrayType, Callable. The time constant of short-term facilitation.
@@ -1030,4 +873,4 @@ def return_info(self):
lambda shape: self.u * self.x)
-STP.__doc__ = STP.__doc__ % (pneu_doc,)
\ No newline at end of file
+STP.__doc__ = STP.__doc__ % (_docs.stp_doc, _docs.pneu_doc,)
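The default ``A`` produced by the new ``_format_dual_exp_A`` helper is the usual peak-normalization factor of a difference-of-exponentials kernel,

.. math::

   A = \frac{\tau_{decay}}{\tau_{decay}-\tau_{rise}}
       \left(\frac{\tau_{rise}}{\tau_{decay}}\right)^{\tau_{rise}/(\tau_{rise}-\tau_{decay})},

which scales :math:`A\,(e^{-t/\tau_{decay}}-e^{-t/\tau_{rise}})` so that its maximum is 1 for a single unit-weight spike. A quick numerical check of that interpretation (plain NumPy, independent of BrainPy):

.. code-block:: python

   import numpy as np

   tau_rise, tau_decay = 1., 10.
   A = tau_decay / (tau_decay - tau_rise) * np.float_power(
       tau_rise / tau_decay, tau_rise / (tau_rise - tau_decay))

   t = np.linspace(0., 100., 100_001)
   kernel = A * (np.exp(-t / tau_decay) - np.exp(-t / tau_rise))

   t_peak = tau_rise * tau_decay / (tau_decay - tau_rise) * np.log(tau_decay / tau_rise)
   print(kernel.max(), t_peak)   # ~1.0, with the peak near t ≈ 2.56 ms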
diff --git a/brainpy/_src/dyn/synapses/delay_couplings.py b/brainpy/_src/dyn/synapses/delay_couplings.py
index ef43139da..8a848e646 100644
--- a/brainpy/_src/dyn/synapses/delay_couplings.py
+++ b/brainpy/_src/dyn/synapses/delay_couplings.py
@@ -64,7 +64,7 @@ def __init__(
self.output_var = var_to_output
# Connection matrix
- self.conn_mat = bm.asarray(conn_mat)
+ self.conn_mat = conn_mat
if self.conn_mat.shape != required_shape:
raise ValueError(f'we expect the structural connection matrix has the shape of '
f'(pre.num, post.num), i.e., {required_shape}, '
diff --git a/brainpy/_src/dyn/synapses/tests/test_abstract_models.py b/brainpy/_src/dyn/synapses/tests/test_abstract_models.py
new file mode 100644
index 000000000..ca028e2e4
--- /dev/null
+++ b/brainpy/_src/dyn/synapses/tests/test_abstract_models.py
@@ -0,0 +1,87 @@
+import unittest
+
+import matplotlib.pyplot as plt
+
+import brainpy as bp
+import brainpy.math as bm
+
+show = False
+
+
+class TestDualExpon(unittest.TestCase):
+ def test_dual_expon(self):
+ # bm.set(dt=0.01)
+
+ class Net(bp.DynSysGroup):
+ def __init__(self, tau_r, tau_d, n_spk):
+ super().__init__()
+
+ self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk))
+ self.proj = bp.dyn.DualExpon(1, tau_rise=tau_r, tau_decay=tau_d)
+
+ def update(self):
+ self.proj(self.inp())
+ return self.proj.h.value, self.proj.g.value
+
+ for tau_r, tau_d in [(1., 10.), (10., 100.)]:
+ for n_spk in [1, 10, 100]:
+ net = Net(tau_r, tau_d, n_spk)
+ indices = bm.as_numpy(bm.arange(1000))
+ hs, gs = bm.for_loop(net.step_run, indices, progress_bar=True)
+
+ bp.visualize.line_plot(indices * bm.get_dt(), hs, legend='h')
+ bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show)
+ plt.close('all')
+
+
+ def test_dual_expon_v2(self):
+ class Net(bp.DynSysGroup):
+ def __init__(self, tau_r, tau_d, n_spk):
+ super().__init__()
+
+ self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk))
+ self.syn = bp.dyn.DualExponV2(1, tau_rise=tau_r, tau_decay=tau_d)
+
+ def update(self):
+ return self.syn(self.inp())
+
+ for tau_r, tau_d in [(1., 10.), (5., 50.), (10., 100.)]:
+ for n_spk in [1, 10, 100]:
+ net = Net(tau_r, tau_d, n_spk)
+ indices = bm.as_numpy(bm.arange(1000))
+ gs = bm.for_loop(net.step_run, indices, progress_bar=True)
+
+ bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show)
+
+ plt.close('all')
+
+class TestAlpha(unittest.TestCase):
+
+ def test_v1(self):
+ class Net(bp.DynSysGroup):
+ def __init__(self, tau, n_spk):
+ super().__init__()
+
+ self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk))
+ self.neu = bp.dyn.LifRef(1)
+ self.proj = bp.dyn.FullProjAlignPreDS(self.inp, None,
+ bp.dyn.Alpha(1, tau_decay=tau),
+ bp.dnn.AllToAll(1, 1, 1.),
+ bp.dyn.CUBA(), self.neu)
+
+ def update(self):
+ self.inp()
+ self.proj()
+ self.neu()
+ return self.proj.syn.h.value, self.proj.syn.g.value
+
+ for tau in [10.]:
+ for n_spk in [1, 10, 50]:
+ net = Net(tau=tau, n_spk=n_spk)
+ indices = bm.as_numpy(bm.arange(1000))
+ hs, gs = bm.for_loop(net.step_run, indices, progress_bar=True)
+
+ bp.visualize.line_plot(indices * bm.get_dt(), hs, legend='h')
+ bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show)
+
+ plt.close('all')
diff --git a/brainpy/_src/dynold/neurons/reduced_models.py b/brainpy/_src/dynold/neurons/reduced_models.py
index d2bf17cc0..9615e1a53 100644
--- a/brainpy/_src/dynold/neurons/reduced_models.py
+++ b/brainpy/_src/dynold/neurons/reduced_models.py
@@ -886,7 +886,7 @@ def __init__(
self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
- def reset_state(self, batch_size=None):
+ def reset_state(self, batch_size=None, **kwargs):
super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
@@ -1023,7 +1023,7 @@ def __init__(
# parameters for training
mode: bm.Mode = None,
- spike_fun: Callable = bm.surrogate.inv_square_grad,
+ spike_fun: Callable = bm.surrogate.inv_square_grad2,
):
# initialization
super(HindmarshRose, self).__init__(size=size,
diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py
index 62b55a0e7..c7a902f01 100644
--- a/brainpy/_src/dynold/synapses/abstract_models.py
+++ b/brainpy/_src/dynold/synapses/abstract_models.py
@@ -7,6 +7,7 @@
import brainpy.math as bm
from brainpy._src.connect import TwoEndConnector, All2All, One2One
from brainpy._src.dnn import linear
+from brainpy._src.dyn import _docs
from brainpy._src.dyn import synapses
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.dynold.synouts import MgBlock, CUBA
@@ -114,7 +115,7 @@ def __init__(
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')
# register delay
- self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
+ self.pre.register_local_delay("spike", self.name, delay_step)
def reset_state(self, batch_size=None):
self.output.reset_state(batch_size)
@@ -124,7 +125,7 @@ def reset_state(self, batch_size=None):
def update(self, pre_spike=None):
# pre-synaptic spikes
if pre_spike is None:
- pre_spike = self.pre.get_delay_data("spike", self.delay_step)
+ pre_spike = self.pre.get_local_delay("spike", self.name)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
@@ -175,32 +176,7 @@ def update(self, pre_spike=None):
class Exponential(TwoEndConn):
r"""Exponential decay synapse model.
- **Model Descriptions**
-
- The single exponential decay synapse model assumes the release of neurotransmitter,
- its diffusion across the cleft, the receptor binding, and channel opening all happen
- very quickly, so that the channels instantaneously jump from the closed to the open state.
- Therefore, its expression is given by
-
- .. math::
-
- g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau}
-
- where :math:`\tau_{delay}` is the time constant of the synaptic state decay,
- :math:`t_0` is the time of the pre-synaptic spike,
- :math:`g_{\mathrm{max}}` is the maximal conductance.
-
- Accordingly, the differential form of the exponential synapse is given by
-
- .. math::
-
- \begin{aligned}
- & g_{\mathrm{syn}}(t) = g_{max} g * \mathrm{STP} \\
- & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}).
- \end{aligned}
-
- where :math:`\mathrm{STP}` is used to model the short-term plasticity effect.
-
+ %s
**Model Examples**
@@ -256,12 +232,6 @@ class Exponential(TwoEndConn):
method: str
The numerical integration methods.
- References
- ----------
-
- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
- "The Synapse." Principles of Computational Modelling in Neuroscience.
- Cambridge: Cambridge UP, 2011. 172-95. Print.
"""
@@ -317,7 +287,7 @@ def __init__(
self.g = self.syn.g
# delay
- self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
+ self.pre.register_local_delay("spike", self.name, delay_step)
def reset_state(self, batch_size=None):
self.syn.reset_state(batch_size)
@@ -328,7 +298,7 @@ def reset_state(self, batch_size=None):
def update(self, pre_spike=None):
# delays
if pre_spike is None:
- pre_spike = self.pre.get_delay_data("spike", self.delay_step)
+ pre_spike = self.pre.get_local_delay("spike", self.name)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
@@ -346,36 +316,13 @@ def update(self, pre_spike=None):
return self.output(g)
-class DualExponential(_TwoEndConnAlignPre):
- r"""Dual exponential synapse model.
-
- **Model Descriptions**
+Exponential.__doc__ = Exponential.__doc__ % (_docs.exp_syn_doc,)
- The dual exponential synapse model [1]_, also named as *difference of two exponentials* model,
- is given by:
- .. math::
-
- g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{
- \tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right)
- -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right)
-
- where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2`
- is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic
- spike, :math:`g_{\mathrm{max}}` is the maximal conductance.
-
- However, in practice, this formula is hard to implement. The equivalent solution is
- two coupled linear differential equations [2]_:
-
- .. math::
-
- \begin{aligned}
- &g_{\mathrm{syn}}(t)=g_{\mathrm{max}} g * \mathrm{STP} \\
- &\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\
- &\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right),
- \end{aligned}
+class DualExponential(_TwoEndConnAlignPre):
+ r"""Dual exponential synapse model.
- where :math:`\mathrm{STP}` is used to model the short-term plasticity effect of synapses.
+ %s
**Model Examples**
@@ -427,15 +374,6 @@ class DualExponential(_TwoEndConnAlignPre):
method: str
The numerical integration methods.
- References
- ----------
-
- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
- "The Synapse." Principles of Computational Modelling in Neuroscience.
- Cambridge: Cambridge UP, 2011. 172-95. Print.
- .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational
- Modeling Methods for Neuroscientists.
-
"""
def __init__(
@@ -450,6 +388,7 @@ def __init__(
tau_decay: Union[float, ArrayType] = 10.0,
tau_rise: Union[float, ArrayType] = 1.,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
+ A: Optional[Union[float, ArrayType, Callable]] = None,
method: str = 'exp_auto',
# other parameters
@@ -472,6 +411,7 @@ def __init__(
syn = synapses.DualExpon(pre.size,
pre.keep_size,
+ A=A,
mode=mode,
tau_decay=tau_decay,
tau_rise=tau_rise,
@@ -498,27 +438,13 @@ def update(self, pre_spike=None):
return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
-class Alpha(DualExponential):
- r"""Alpha synapse model.
+DualExponential.__doc__ = DualExponential.__doc__ % (_docs.dual_exp_syn_doc,)
- **Model Descriptions**
-
- The analytical expression of alpha synapse is given by:
-
- .. math::
- g_{syn}(t)= g_{max} \frac{t-t_{s}}{\tau} \exp \left(-\frac{t-t_{s}}{\tau}\right).
-
- While, this equation is hard to implement. So, let's try to convert it into the
- differential forms:
-
- .. math::
+class Alpha(_TwoEndConnAlignPre):
+ r"""Alpha synapse model.
- \begin{aligned}
- &g_{\mathrm{syn}}(t)= g_{\mathrm{max}} g \\
- &\frac{d g}{d t}=-\frac{g}{\tau}+h \\
- &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
- \end{aligned}
+ %s
**Model Examples**
@@ -567,12 +493,6 @@ class Alpha(DualExponential):
method: str
The numerical integration methods.
- References
- ----------
-
- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
- "The Synapse." Principles of Computational Modelling in Neuroscience.
- Cambridge: Cambridge UP, 2011. 172-95. Print.
"""
def __init__(
@@ -593,20 +513,42 @@ def __init__(
mode: bm.Mode = None,
stop_spike_gradient: bool = False,
):
- super(Alpha, self).__init__(pre=pre,
- post=post,
- conn=conn,
- comp_method=comp_method,
- delay_step=delay_step,
- g_max=g_max,
- tau_decay=tau_decay,
- tau_rise=tau_decay,
- method=method,
- output=output,
- stp=stp,
- name=name,
- mode=mode,
- stop_spike_gradient=stop_spike_gradient)
+ # parameters
+ self.stop_spike_gradient = stop_spike_gradient
+ self.comp_method = comp_method
+ self.tau_decay = tau_decay
+ if bm.size(self.tau_decay) != 1:
+ raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
+ f'But we got {self.tau_decay}')
+
+ syn = synapses.Alpha(pre.size,
+ pre.keep_size,
+ mode=mode,
+ tau_decay=tau_decay,
+ method=method)
+
+ super().__init__(pre=pre,
+ post=post,
+ syn=syn,
+ conn=conn,
+ comp_method=comp_method,
+ delay_step=delay_step,
+ g_max=g_max,
+ output=output,
+ stp=stp,
+ name=name,
+ mode=mode, )
+
+ self.check_post_attrs('input')
+ # copy the references
+ self.g = syn.g
+ self.h = syn.h
+
+ def update(self, pre_spike=None):
+ return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
+
+
+Alpha.__doc__ = Alpha.__doc__ % (_docs.alpha_syn_doc,)
class NMDA(_TwoEndConnAlignPre):
diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py
index 02a0355aa..55bac7111 100644
--- a/brainpy/_src/dynold/synapses/base.py
+++ b/brainpy/_src/dynold/synapses/base.py
@@ -6,7 +6,7 @@
from brainpy import math as bm
from brainpy._src.connect import TwoEndConnector, One2One, All2All
from brainpy._src.dnn import linear
-from brainpy._src.dyn import projections
+from brainpy._src.dyn.projections.conn import SynConn
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import parameter
@@ -29,7 +29,7 @@ class _SynapseComponent(DynamicalSystem):
synaptic long-term plasticity, and others. """
'''Master of this component.'''
- master: projections.SynConn
+ master: SynConn
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -50,9 +50,9 @@ def isregistered(self, val: bool):
def reset_state(self, batch_size=None):
pass
- def register_master(self, master: projections.SynConn):
- if not isinstance(master, projections.SynConn):
- raise TypeError(f'master must be instance of {projections.SynConn.__name__}, but we got {type(master)}')
+ def register_master(self, master: SynConn):
+ if not isinstance(master, SynConn):
+ raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}')
if self.isregistered:
raise ValueError(f'master has been registered, but we got another master going to be registered.')
if hasattr(self, 'master') and self.master != master:
@@ -90,7 +90,7 @@ def __init__(
f'But we got {type(target_var)}')
self.target_var: Optional[bm.Variable] = target_var
- def register_master(self, master: projections.SynConn):
+ def register_master(self, master: SynConn):
super().register_master(master)
# initialize target variable to output
@@ -125,7 +125,7 @@ def clone(self):
return _NullSynOut()
-class TwoEndConn(projections.SynConn):
+class TwoEndConn(SynConn):
"""Base class to model synaptic connections.
Parameters
@@ -296,7 +296,7 @@ def __init__(
mode=mode)
# delay
- self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
+ self.pre.register_local_delay("spike", self.name, delay_step)
# synaptic dynamics
self.syn = syn
@@ -317,7 +317,7 @@ def __init__(
def update(self, pre_spike=None, stop_spike_gradient: bool = False):
if pre_spike is None:
- pre_spike = self.pre.get_delay_data("spike", self.delay_step)
+ pre_spike = self.pre.get_local_delay("spike", self.name)
if stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
if self.stp is not None:
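The legacy two-end synapses now route spike delays through name-keyed local delay entries instead of keeping a ``delay_step`` handle: an entry is registered once under the consumer's name and read back under the same name at update time. A minimal sketch of the pattern in isolation (assuming, as in the calls above, that ``None`` means no extra delay):

.. code-block:: python

   import brainpy as bp

   pre = bp.dyn.LifRef(10)

   # at construction time: register a delay entry keyed by the consumer's name
   pre.register_local_delay("spike", "my_synapse", None)

   # at update time: read the (possibly delayed) spikes back under the same key
   pre_spike = pre.get_local_delay("spike", "my_synapse")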
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 00120a666..cb086b10d 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -2,8 +2,8 @@
import collections
import inspect
-import warnings
import numbers
+import warnings
from typing import Union, Dict, Callable, Sequence, Optional, Any
import numpy as np
@@ -13,7 +13,7 @@
from brainpy._src.deprecations import _update_deprecate_msg
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, _get_delay_tool
-from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError
+from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
__all__ = [
@@ -27,9 +27,11 @@
'Dynamic', 'Projection',
]
-
IonChaDyn = None
SLICE_VARS = 'slice_vars'
+the_top_layer_reset_state = True
+clear_input = None
+reset_state = None
def not_implemented(fun):
@@ -91,7 +93,8 @@ def __init__(
# Attribute for "SupportInputProj"
# each instance of "SupportInputProj" should have a "cur_inputs" attribute
- self.cur_inputs = bm.node_dict()
+ self.current_inputs = bm.node_dict()
+ self.delta_inputs = bm.node_dict()
# the before- / after-updates used for computing
# added after the version of 2.4.3
@@ -138,20 +141,19 @@ def update(self, *args, **kwargs):
"""
raise NotImplementedError('Must implement "update" function by subclass self.')
- def reset(self, *args, include_self: bool = False, **kwargs):
+ def reset(self, *args, **kwargs):
"""Reset function which reset the whole variables in the model (including its children models).
``reset()`` function is a collective behavior which resets all states in this model.
See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
-
- Args::
- include_self: bool. Reset states including the node self. Please turn on this if the node has
- implemented its ".reset_state()" function.
"""
- from brainpy._src.helpers import reset_state
+ global reset_state
+ if reset_state is None:
+ from brainpy._src.helpers import reset_state
reset_state(self, *args, **kwargs)
+ @not_implemented
def reset_state(self, *args, **kwargs):
"""Reset function which resets local states in this model.
@@ -162,19 +164,6 @@ def reset_state(self, *args, **kwargs):
"""
pass
- # raise APIChangedError(
- # '''
- # From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
- #
- # 1. If you are resetting all states in a network by calling "net.reset_state()", please use
- # "bp.reset_state(net)" function. ".reset_state()" only defines the resetting of local states
- # in a local node (excluded its children nodes).
- #
- # 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.
- #
- # '''
- # )
-
def clear_input(self, *args, **kwargs):
"""Clear the input at the current time step."""
pass
@@ -193,8 +182,13 @@ def step_run(self, i, *args, **kwargs):
Returns:
out: The update function returns.
"""
+ global clear_input
+ if clear_input is None:
+ from brainpy._src.helpers import clear_input
share.save(i=i, t=i * bm.dt)
- return self.update(*args, **kwargs)
+ out = self.update(*args, **kwargs)
+ clear_input(self)
+ return out
@bm.cls_jit(inline=True)
def jit_step_run(self, i, *args, **kwargs):
@@ -344,14 +338,40 @@ def _compatible_update(self, *args, **kwargs):
return ret
return update_fun(*args, **kwargs)
+ def _compatible_reset_state(self, *args, **kwargs):
+ global the_top_layer_reset_state
+ the_top_layer_reset_state = False
+ try:
+ if hasattr(self.reset_state, '_not_implemented'):
+ self.reset(*args, **kwargs)
+ warnings.warn(
+ '''
+ From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_saving_and_loading.html for details.
+
+ 1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use the
+ "bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)".
+ ".reset_state()" only defines the resetting of local states in a local node (excluding its children nodes).
+
+ 2. If you do not customize the "reset_state()" function for a local node, please implement it in your subclass.
+
+ ''',
+ DeprecationWarning
+ )
+ else:
+ self.reset_state(*args, **kwargs)
+ finally:
+ the_top_layer_reset_state = True
+
def _get_update_fun(self):
return object.__getattribute__(self, 'update')
def __getattribute__(self, item):
if item == 'update':
return self._compatible_update # update function compatible with previous ``update()`` function
- else:
- return super().__getattribute__(item)
+ if item == 'reset_state':
+ if the_top_layer_reset_state:
+ return self._compatible_reset_state # reset_state function compatible with previous ``reset_state()`` function
+ return super().__getattribute__(item)
def __repr__(self):
return f'{self.name}(mode={self.mode})'
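The `dynsys.py` changes above route `net.reset()` through `bp.reset_state()` and make `step_run()` clear inputs after every update. A minimal sketch of the new pattern, assuming only the documented API (`Counter` below is an illustrative user-defined node, not part of BrainPy):

```python
import brainpy as bp
import brainpy.math as bm


class Counter(bp.DynamicalSystem):
  def __init__(self):
    super().__init__()
    self.count = bm.Variable(bm.zeros(1))

  def reset_state(self, *args, **kwargs):
    # reset only the states owned by this node; children reset themselves
    self.count.value = bm.zeros(1)

  def update(self):
    self.count += 1
    return self.count.value


net = Counter()
net.update()
bp.reset_state(net)   # collective reset through the new helper
net.reset()           # equivalent collective reset on the instance itself
```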
diff --git a/brainpy/_src/helpers.py b/brainpy/_src/helpers.py
index 9352ff850..ab0a306e9 100644
--- a/brainpy/_src/helpers.py
+++ b/brainpy/_src/helpers.py
@@ -1,11 +1,12 @@
-from typing import Dict
+from typing import Dict, Callable
+from brainpy._src import dynsys
from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.dynsys import DynamicalSystem, DynView
from brainpy._src.math.object_transform.base import StateLoadResult
-
__all__ = [
+ 'reset_level',
'reset_state',
'load_state',
'save_state',
@@ -13,6 +14,34 @@
]
+_max_level = 10
+
+
+def reset_level(level: int = 0):
+ """The decorator for indicating the resetting level.
+
+ The decorator takes an optional integer argument ``level`` with a default value of 0.
+ The lower the level, the earlier the decorated ``reset_state`` function is called during a
+ collective reset. Negative levels count back from the maximum level (10).
+
+ >>> import brainpy as bp
+ >>> @bp.reset_level(0)
+ ... def reset_state(self, *args, **kwargs): ...
+ >>> @bp.reset_level(-1)
+ ... def reset_state(self, *args, **kwargs): ...
+
+ """
+ if level < 0:
+ level = _max_level + level
+ if level < 0 or level >= _max_level:
+ raise ValueError(f'"reset_level" must be an integer in [0, {_max_level}), but we got {level}.')
+
+ def wrap(fun: Callable):
+ fun.reset_level = level
+ return fun
+
+ return wrap
+
+
def reset_state(target: DynamicalSystem, *args, **kwargs):
"""Reset states of all children nodes in the given target.
@@ -20,11 +49,28 @@ def reset_state(target: DynamicalSystem, *args, **kwargs):
Args:
target: The target DynamicalSystem.
- *args:
- **kwargs:
"""
- for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values():
- node.reset_state(*args, **kwargs)
+ dynsys.the_top_layer_reset_state = False
+
+ try:
+ nodes = list(target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values())
+ nodes_with_level = []
+
+ # first, reset the nodes whose `reset_state` has no `reset_level`
+ for node in nodes:
+ if not hasattr(node.reset_state, 'reset_level'):
+ node.reset_state(*args, **kwargs)
+ else:
+ nodes_with_level.append(node)
+
+ # then, reset the nodes with a `reset_level`, from the lowest level to the highest
+ for l in range(_max_level):
+ for node in nodes_with_level:
+ if node.reset_state.reset_level == l:
+ node.reset_state(*args, **kwargs)
+
+ finally:
+ dynsys.the_top_layer_reset_state = True
def clear_input(target: DynamicalSystem, *args, **kwargs):
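A sketch of how the new `reset_level` decorator orders resets when `bp.reset_state()` walks a network; `Neuron` and `Synapse` are hypothetical stand-ins, and `bp.DynSysGroup` is assumed to act as a plain container:

```python
import brainpy as bp
import brainpy.math as bm


class Neuron(bp.DynamicalSystem):
  def __init__(self):
    super().__init__()
    self.v = bm.Variable(bm.zeros(10))

  def reset_state(self, *args, **kwargs):   # no level: reset in the first pass
    self.v.value = bm.zeros(10)

  def update(self, x):
    self.v += x
    return self.v.value


class Synapse(bp.DynamicalSystem):
  def __init__(self):
    super().__init__()
    self.g = bm.Variable(bm.zeros(10))

  @bp.reset_level(-1)                       # negative levels count back from 10
  def reset_state(self, *args, **kwargs):
    # called after all un-leveled resets and after lower levels
    self.g.value = bm.zeros(10)

  def update(self):
    return self.g.value


net = bp.DynSysGroup(Neuron(), Synapse())
bp.reset_state(net)   # un-leveled nodes first, then levels 0..9 in order
```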
diff --git a/brainpy/_src/initialize/random_inits.py b/brainpy/_src/initialize/random_inits.py
index 893ed06b1..fbad02dd9 100644
--- a/brainpy/_src/initialize/random_inits.py
+++ b/brainpy/_src/initialize/random_inits.py
@@ -11,6 +11,7 @@
__all__ = [
'Normal',
+ 'TruncatedNormal',
'Uniform',
'VarianceScaling',
'KaimingUniform',
@@ -82,7 +83,7 @@ def _format_shape(shape):
if len(shape) == 0:
raise ValueError('Please provide shape.')
if len(shape) == 1:
- if isinstance(shape, (tuple, list)):
+ if isinstance(shape[0], (tuple, list)):
return shape[0]
else:
return shape
@@ -122,6 +123,50 @@ def __repr__(self):
return f'{self.__class__.__name__}(scale={self.scale}, rng={self.rng})'
+class TruncatedNormal(_InterLayerInitializer):
+ """Initialize weights with truncated normal distribution.
+
+ Parameters
+ ----------
+ loc : float, ndarray
+ Mean ("centre") of the distribution before truncating. Note that
+ the mean of the truncated distribution will not be exactly equal
+ to ``loc``.
+ scale : float
+ The standard deviation of the normal distribution before truncating.
+ lower : float, ndarray
+ A float or array of floats representing the lower bound for
+ truncation. Must be broadcast-compatible with ``upper``.
+ upper : float, ndarray
+ A float or array of floats representing the upper bound for
+ truncation. Must be broadcast-compatible with ``lower``.
+ seed : int, optional
+ The seed used to initialize the internal random number generator.
+
+ """
+
+ def __init__(self, loc=0., scale=1., lower=None, upper=None, seed=None):
+ super(TruncatedNormal, self).__init__()
+ assert scale > 0, '`scale` must be positive.'
+ self.scale = scale
+ self.loc = loc
+ self.lower = lower
+ self.upper = upper
+ self.rng = bm.random.default_rng(seed, clone=False)
+
+ def __call__(self, shape, dtype=None):
+ shape = _format_shape(shape)
+ weights = self.rng.truncated_normal(
+ size=shape,
+ scale=self.scale,
+ lower=self.lower,
+ upper=self.upper,
+ loc=self.loc
+ )
+ return bm.asarray(weights, dtype=dtype)
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(loc={self.loc}, scale={self.scale}, lower={self.lower}, upper={self.upper}, rng={self.rng})'
+
+
class Gamma(_InterLayerInitializer):
"""Initialize weights with Gamma distribution.
@@ -227,7 +272,7 @@ def __call__(self, shape, dtype=None):
variance = (self.scale / denominator).astype(dtype)
if self.distribution == "truncated_normal":
stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
- res = self.rng.truncated_normal(-2, 2, shape, dtype) * stddev
+ res = self.rng.truncated_normal(-2, 2, shape).astype(dtype) * stddev
elif self.distribution == "normal":
res = self.rng.randn(*shape) * jnp.sqrt(variance).astype(dtype)
elif self.distribution == "uniform":
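A short usage sketch for the new `TruncatedNormal` initializer; the shape, bounds and seed are arbitrary, and `bp.init` is assumed to re-export it alongside `Normal`:

```python
import brainpy as bp

init = bp.init.TruncatedNormal(loc=0., scale=1., lower=-2., upper=2., seed=42)
w = init((128, 64))     # samples from N(0, 1) truncated to [-2, 2]
print(w.shape)          # (128, 64)
```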
diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py
index 2e577e6ab..e44e324e7 100644
--- a/brainpy/_src/integrators/ode/exponential.py
+++ b/brainpy/_src/integrators/ode/exponential.py
@@ -105,8 +105,6 @@
.. [2] Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286.
"""
-import logging
-
from functools import wraps
from brainpy import errors
from brainpy._src import math as bm
@@ -360,9 +358,7 @@ def integral(*args, **kwargs):
assert len(args) > 0
dt = kwargs.pop(C.DT, self.dt)
linear, derivative = value_and_grad(*args, **kwargs)
- phi = bm.where(linear == 0.,
- bm.ones_like(linear),
- (bm.exp(dt * linear) - 1) / (dt * linear))
+ phi = bm.exprel(dt * linear)
return args[0] + dt * phi * derivative
return [(integral, vars, pars), ]
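The exponential Euler update needs `phi(z) = (exp(z) - 1) / z`, which the old `where`-based expression evaluates poorly near `z = 0`; `bm.exprel` computes the same quantity stably (the same substitution is made in the SDE integrator below). A quick numerical check:

```python
import brainpy.math as bm

z = bm.asarray([1e-8, 0.5, 2.0])
print(bm.exprel(z))            # ~[1.0, 1.2974, 3.1945]
print((bm.exp(z) - 1.) / z)    # the first entry collapses to 0.0 in float32
```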
diff --git a/brainpy/_src/integrators/sde/normal.py b/brainpy/_src/integrators/sde/normal.py
index b7de12515..34dbafff1 100644
--- a/brainpy/_src/integrators/sde/normal.py
+++ b/brainpy/_src/integrators/sde/normal.py
@@ -626,8 +626,7 @@ def integral(*args, **kwargs):
assert len(args) > 0
dt = kwargs.pop('dt', self.dt)
linear, derivative = value_and_grad(*args, **kwargs)
- linear = bm.as_jax(linear)
- phi = jnp.where(linear == 0., jnp.ones_like(linear), (jnp.exp(dt * linear) - 1) / (dt * linear))
+ phi = bm.as_jax(bm.exprel(dt * linear))
return args[0] + dt * phi * derivative
return [(integral, vars, pars), ]
diff --git a/brainpy/_src/losses/comparison.py b/brainpy/_src/losses/comparison.py
index 8d8fb1388..ad0c3ea35 100644
--- a/brainpy/_src/losses/comparison.py
+++ b/brainpy/_src/losses/comparison.py
@@ -39,6 +39,7 @@
'log_cosh_loss',
'ctc_loss_with_forward_probs',
'ctc_loss',
+ 'multi_margin_loss',
]
@@ -1050,3 +1051,47 @@ def ctc_loss(logits: ArrayType,
logits, logit_paddings, labels, label_paddings,
blank_id=blank_id, log_epsilon=log_epsilon)
return per_seq_loss
+
+
+def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'):
+ r"""Computes multi-class margin loss, also called multi-class hinge loss.
+
+ This loss function is often used in multi-class classification problems.
+ It is a type of hinge loss that tries to ensure the correct class score is greater than the scores of other classes by a margin.
+
+ The loss function for sample :math:`i` is:
+
+ .. math::
+ \ell(x_i, y_i) = \sum_{j \neq y_i} \max(0, x_{i,j} - x_{i,y_i} + \text{margin})^p
+
+ where :math:`x_i` is the score vector of sample :math:`i`, :math:`y_i` is the index of its true class,
+ and :math:`j \in \left\{0, \; \cdots , \; C - 1\right\}` with :math:`j \neq y_i`.
+
+ Args:
+ predicts: :math:`(N, C)` where `C = number of classes`.
+ targets: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`.
+ margin (float, optional): Has a default value of :math:`1`.
+ p (float, optional): Has a default value of :math:`1`.
+ reduction (str, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
+ be applied, ``'mean'``: the sum of the output will be divided by the
+ number of elements in the output, ``'sum'``: the output will be summed.
+ Default: ``'mean'``
+
+ Returns:
+ a scalar representing the multi-class margin loss. If `reduction` is ``'none'``, then :math:`(N)`.
+ """
+ assert p == 1 or p == 2, 'p should be 1 or 2'
+ batch_size = predicts.shape[0]
+ correct_scores = predicts[jnp.arange(batch_size), targets]
+ margins = jnp.power(jnp.maximum(0, predicts - correct_scores[:, jnp.newaxis] + margin), p)
+ margins = margins.at[jnp.arange(batch_size), targets].set(0)
+ if reduction == 'mean':
+ return jnp.sum(margins) / batch_size
+ elif reduction == 'sum':
+ return jnp.sum(margins)
+ elif reduction == 'none':
+ return margins.sum(axis=1)
+ else:
+ raise ValueError(f"Unsupported reduction: {reduction!r}. Expected 'none', 'mean' or 'sum'.")
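A small worked example for the new loss, assuming it is re-exported as `bp.losses.multi_margin_loss` as the updated `__all__` suggests:

```python
import jax.numpy as jnp
import brainpy as bp

scores = jnp.array([[2.0, 1.5, 0.3],    # sample 0, true class 0
                    [0.1, 1.5, 1.2]])   # sample 1, true class 1
labels = jnp.array([0, 1])

loss = bp.losses.multi_margin_loss(scores, labels, margin=1.0, p=1)
# per-sample hinge terms: max(0, 1.5 - 2.0 + 1) = 0.5 and max(0, 1.2 - 1.5 + 1) = 0.7
print(loss)   # 0.6, the mean of 0.5 and 0.7
```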
diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py
index 3128c5e67..3102bc1d0 100644
--- a/brainpy/_src/math/__init__.py
+++ b/brainpy/_src/math/__init__.py
@@ -44,7 +44,7 @@
from .compat_numpy import *
from .compat_tensorflow import *
from .others import *
-from . import random, linalg, fft
+from . import random, linalg, fft, tifunc
# operators
from .op_register import *
diff --git a/brainpy/_src/math/activations.py b/brainpy/_src/math/activations.py
index 60c7991f1..54ced5d4d 100644
--- a/brainpy/_src/math/activations.py
+++ b/brainpy/_src/math/activations.py
@@ -298,7 +298,7 @@ def leaky_relu(x, negative_slope=1e-2):
return jnp.where(x >= 0, x, negative_slope * x)
-def softplus(x, beta=1, threshold=20):
+def softplus(x, beta: float = 1., threshold: float = 20.):
r"""Softplus activation function.
Computes the element-wise function
@@ -315,12 +315,12 @@ def softplus(x, beta=1, threshold=20):
Parameters
----------
x: The input array.
- beta: the :math:`\beta` value for the Softplus formulation. Default: 1
- threshold: values above this revert to a linear function. Default: 20
+ beta: the :math:`\beta` value for the Softplus formulation. Default: 1.
+ threshold: values above this revert to a linear function. Default: 20.
"""
x = x.value if isinstance(x, Array) else x
- return jnp.where(x > threshold, x * beta, 1 / beta * jnp.logaddexp(beta * x, 0))
+ return jnp.where(x > threshold / beta, x, 1 / beta * jnp.logaddexp(beta * x, 0))
def log_sigmoid(x):
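The `softplus` fix changes the saturation test: for `beta != 1` the function should revert to the identity (not `beta * x`) once `beta * x` exceeds the threshold, so the switch happens at `x > threshold / beta`. A quick check of the corrected behaviour, assuming `softplus` stays exported under `brainpy.math`:

```python
import brainpy.math as bm

x = bm.asarray([1.0, 5.0, 30.0])
print(bm.softplus(x, beta=4.0, threshold=20.0))
# ~[1.0045, 5.0, 30.0]; only the last entry takes the linear branch (4 * 30 > 20)
```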
diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py
index a5ffc2984..213185df1 100644
--- a/brainpy/_src/math/compat_numpy.py
+++ b/brainpy/_src/math/compat_numpy.py
@@ -205,6 +205,23 @@ def asfarray(a, dtype=np.float_):
dtype = np.float_
return asarray(a, dtype=dtype)
+def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
+ del assume_unique
+ ar1_flat = ravel(ar1)
+ ar2_flat = ravel(ar2)
+ # Note: an algorithm based on searchsorted has better scaling, but in practice
+ # is very slow on accelerators because it relies on lax control flow. If XLA
+ # ever supports binary search natively, we should switch to this:
+ # ar2_flat = jnp.sort(ar2_flat)
+ # ind = jnp.searchsorted(ar2_flat, ar1_flat)
+ # if invert:
+ # return ar1_flat != ar2_flat[ind]
+ # else:
+ # return ar1_flat == ar2_flat[ind]
+ if invert:
+ return asarray((ar1_flat[:, None] != ar2_flat[None, :]).all(-1))
+ else:
+ return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1))
# Others
# ------
@@ -237,7 +254,6 @@ def asfarray(a, dtype=np.float_):
histogram_bin_edges = _compatible_with_brainpy_array(jnp.histogram_bin_edges)
histogramdd = _compatible_with_brainpy_array(jnp.histogramdd)
i0 = _compatible_with_brainpy_array(jnp.i0)
-in1d = _compatible_with_brainpy_array(jnp.in1d)
indices = _compatible_with_brainpy_array(jnp.indices)
insert = _compatible_with_brainpy_array(jnp.insert)
intersect1d = _compatible_with_brainpy_array(jnp.intersect1d)
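`in1d` is now implemented with a dense pairwise comparison (quadratic in memory but accelerator-friendly) instead of wrapping `jnp.in1d`; the observable behaviour is unchanged:

```python
import brainpy.math as bm

a = bm.asarray([0, 1, 2, 5, 0])
b = bm.asarray([0, 2])
print(bm.in1d(a, b))                 # [ True False  True False  True]
print(bm.in1d(a, b, invert=True))    # [False  True False  True False]
```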
diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py
index 86695e440..192eb6709 100644
--- a/brainpy/_src/math/compat_pytorch.py
+++ b/brainpy/_src/math/compat_pytorch.py
@@ -1,17 +1,16 @@
-from typing import Union, Optional
+from typing import Union, Optional, Sequence
import jax
import jax.numpy as jnp
import numpy as np
+from .compat_numpy import (concatenate, minimum, maximum, )
from .ndarray import Array, _as_jax_array_, _return, _check_out
-from .compat_numpy import (
- concatenate, shape, minimum, maximum,
-)
__all__ = [
'Tensor',
'flatten',
+ 'unflatten',
'cat',
'abs',
'absolute',
@@ -85,31 +84,62 @@ def flatten(input: Union[jax.Array, Array],
return jnp.reshape(input, new_shape)
-def unsqueeze(input: Union[jax.Array, Array], dim: int) -> Array:
+def unflatten(x: Union[jax.Array, Array], dim: int, sizes: Sequence[int]) -> Array:
+ """
+ Expands a dimension of the input tensor over multiple dimensions.
+
+ Args:
+ x: input tensor.
+ dim: Dimension to be unflattened, specified as an index into ``x.shape``.
+ sizes: New shape of the unflattened dimension. One of its elements can be -1
+ in which case the corresponding output dimension is inferred.
+ Otherwise, the product of ``sizes`` must equal ``x.shape[dim]``.
+
+ Returns:
+ A tensor with the same data as ``x``, but with ``dim`` split into multiple dimensions.
+ The returned tensor has ``len(sizes) - 1`` more dimensions than the input tensor.
+ """
+ assert x.ndim > dim, ('The dimension to be unflattened should be less than the tensor dimension. '
+ f'Got {dim} and {x.ndim}.')
+ x = _as_jax_array_(x)
+ shape = x.shape
+ new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:]
+ r = jnp.reshape(x, new_shape)
+ return _return(r)
+
+
+def unsqueeze(x: Union[jax.Array, Array], dim: int) -> Array:
"""Returns a new tensor with a dimension of size one inserted at the specified position.
-The returned tensor shares the same underlying data with this tensor.
-A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used.
-Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
-Parameters
-----------
-input: Array
- The input Array
-dim: int
- The index at which to insert the singleton dimension
-
-Returns
--------
-out: Array
-"""
- input = _as_jax_array_(input)
- return Array(jnp.expand_dims(input, dim))
+
+ The returned tensor shares the same underlying data with this tensor.
+ A dim value within the range ``[-input.dim() - 1, input.dim() + 1)`` can be used.
+ Negative dim will correspond to unsqueeze() applied at ``dim = dim + input.dim() + 1``.
+
+ Parameters
+ ----------
+ x: Array
+ The input Array
+ dim: int
+ The index at which to insert the singleton dimension
+
+ Returns
+ -------
+ out: Array
+ """
+ x = _as_jax_array_(x)
+ r = jnp.expand_dims(x, dim)
+ return _return(r)
# Math operations
-def abs(input: Union[jax.Array, Array],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.abs(input)
+def abs(
+ x: Union[jax.Array, Array],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x = _as_jax_array_(x)
+ r = jnp.abs(x)
if out is None:
return _return(r)
else:
@@ -120,10 +150,13 @@ def abs(input: Union[jax.Array, Array],
absolute = abs
-def acos(input: Union[jax.Array, Array],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.arccos(input)
+def acos(
+ x: Union[jax.Array, Array],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x = _as_jax_array_(x)
+ r = jnp.arccos(x)
if out is None:
return _return(r)
else:
@@ -134,10 +167,13 @@ def acos(input: Union[jax.Array, Array],
arccos = acos
-def acosh(input: Union[jax.Array, Array],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.arccosh(input)
+def acosh(
+ x: Union[jax.Array, Array],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x = _as_jax_array_(x)
+ r = jnp.arccosh(x)
if out is None:
return _return(r)
else:
@@ -148,14 +184,25 @@ def acosh(input: Union[jax.Array, Array],
arccosh = acosh
-def add(input: Union[jax.Array, Array, jnp.number],
- other: Union[jax.Array, Array, jnp.number],
- *, alpha: Optional[jnp.number] = 1,
- out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- other = _as_jax_array_(other)
- other = jnp.multiply(alpha, other)
- r = jnp.add(input, other)
+def add(
+ x: Union[jax.Array, Array, jnp.number],
+ y: Union[jax.Array, Array, jnp.number],
+ *,
+ alpha: Optional[jnp.number] = 1,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ r"""
+ Adds ``y``, scaled by ``alpha``, to ``x``.
+
+ .. math::
+
+ \text{out}_i = x_i + \text{alpha} \times y_i
+
+ """
+ x = _as_jax_array_(x)
+ y = _as_jax_array_(y)
+ y = jnp.multiply(alpha, y)
+ r = jnp.add(x, y)
if out is None:
return _return(r)
else:
@@ -163,32 +210,41 @@ def add(input: Union[jax.Array, Array, jnp.number],
out.value = r
-def addcdiv(input: Union[jax.Array, Array, jnp.number],
- tensor1: Union[jax.Array, Array, jnp.number],
- tensor2: Union[jax.Array, Array, jnp.number],
- *, value: jnp.number = 1,
- out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
+def addcdiv(
+ x: Union[jax.Array, Array, jnp.number],
+ tensor1: Union[jax.Array, Array, jnp.number],
+ tensor2: Union[jax.Array, Array, jnp.number],
+ *,
+ value: jnp.number = 1,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
tensor1 = _as_jax_array_(tensor1)
tensor2 = _as_jax_array_(tensor2)
other = jnp.divide(tensor1, tensor2)
- return add(input, other, alpha=value, out=out)
+ return add(x, other, alpha=value, out=out)
-def addcmul(input: Union[jax.Array, Array, jnp.number],
- tensor1: Union[jax.Array, Array, jnp.number],
- tensor2: Union[jax.Array, Array, jnp.number],
- *, value: jnp.number = 1,
- out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
+def addcmul(
+ x: Union[jax.Array, Array, jnp.number],
+ tensor1: Union[jax.Array, Array, jnp.number],
+ tensor2: Union[jax.Array, Array, jnp.number],
+ *,
+ value: jnp.number = 1,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
tensor1 = _as_jax_array_(tensor1)
tensor2 = _as_jax_array_(tensor2)
other = jnp.multiply(tensor1, tensor2)
- return add(input, other, alpha=value, out=out)
+ return add(x, other, alpha=value, out=out)
-def angle(input: Union[jax.Array, Array, jnp.number],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.angle(input)
+def angle(
+ x: Union[jax.Array, Array, jnp.number],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x = _as_jax_array_(x)
+ r = jnp.angle(x)
if out is None:
return _return(r)
else:
@@ -196,10 +252,13 @@ def angle(input: Union[jax.Array, Array, jnp.number],
out.value = r
-def asin(input: Union[jax.Array, Array],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.arcsin(input)
+def asin(
+ x: Union[jax.Array, Array],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x = _as_jax_array_(x)
+ r = jnp.arcsin(x)
if out is None:
return _return(r)
else:
@@ -210,10 +269,13 @@ def asin(input: Union[jax.Array, Array],
arcsin = asin
-def asinh(input: Union[jax.Array, Array],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.arcsinh(input)
+def asinh(
+ x: Union[jax.Array, Array],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x = _as_jax_array_(x)
+ r = jnp.arcsinh(x)
if out is None:
return _return(r)
else:
@@ -224,10 +286,13 @@ def asinh(input: Union[jax.Array, Array],
arcsinh = asinh
-def atan(input: Union[jax.Array, Array],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.arctan(input)
+def atan(
+ x: Union[jax.Array, Array],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x = _as_jax_array_(x)
+ r = jnp.arctan(x)
if out is None:
return _return(r)
else:
@@ -238,10 +303,13 @@ def atan(input: Union[jax.Array, Array],
arctan = atan
-def atanh(input: Union[jax.Array, Array],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.arctanh(input)
+def atanh(
+ x: Union[jax.Array, Array],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x = _as_jax_array_(x)
+ r = jnp.arctanh(x)
if out is None:
return _return(r)
else:
@@ -252,10 +320,15 @@ def atanh(input: Union[jax.Array, Array],
arctanh = atanh
-def atan2(input: Union[jax.Array, Array],
- *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
- input = _as_jax_array_(input)
- r = jnp.arctan2(input)
+def atan2(
+ x1: Union[jax.Array, Array],
+ x2: Union[jax.Array, Array],
+ *,
+ out: Optional[Union[Array, jax.Array, np.ndarray]] = None
+) -> Optional[Array]:
+ x1 = _as_jax_array_(x1)
+ x2 = _as_jax_array_(x2)
+ r = jnp.arctan2(x1, x2)
if out is None:
return _return(r)
else:
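Two of the torch-style additions in use; this sketch assumes `unflatten` and `atan2` are re-exported under `brainpy.math` like the other compat helpers:

```python
import brainpy.math as bm

x = bm.arange(24.0).reshape(6, 4)
y = bm.unflatten(x, dim=0, sizes=(2, 3))   # split axis 0: (6, 4) -> (2, 3, 4)
print(y.shape)                             # (2, 3, 4)

print(bm.atan2(bm.asarray([1.0, -1.0]), bm.asarray([1.0, 1.0])))   # [ 0.7854 -0.7854]
```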
diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py
new file mode 100644
index 000000000..19aca92cf
--- /dev/null
+++ b/brainpy/_src/math/defaults.py
@@ -0,0 +1,38 @@
+import jax.numpy as jnp
+from jax import config
+
+from brainpy._src.dependency_check import import_taichi
+from .modes import NonBatchingMode
+from .scales import IdScaling
+
+__all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_']
+
+ti = import_taichi()
+
+# Default computation mode.
+mode = NonBatchingMode()
+
+# '''Default membrane potential scaling.'''
+membrane_scaling = IdScaling()
+
+# '''Default time step.'''
+dt = 0.1
+
+# '''Default bool data type.'''
+bool_ = jnp.bool_
+
+# '''Default integer data type.'''
+int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32
+
+# '''Default integer data type in Taichi.'''
+ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32
+
+# '''Default float data type.'''
+float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32
+
+# '''Default float data type in Taichi.'''
+ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32
+
+# '''Default complex data type.'''
+complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64
+
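The new `defaults.py` centralises the default mode, scaling, time step and dtypes; the dtype choices follow JAX's `jax_enable_x64` flag at import time. A sketch of inspecting them, assuming they remain surfaced through `brainpy.math` as before:

```python
import brainpy.math as bm

print(bm.float_)     # jnp.float32 unless 'jax_enable_x64' was enabled beforehand
print(bm.int_)       # jnp.int32 under the same condition
print(bm.get_dt())   # default integration step, 0.1
```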
diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py
index eb8e27c8f..676e4286b 100644
--- a/brainpy/_src/math/delayvars.py
+++ b/brainpy/_src/math/delayvars.py
@@ -11,7 +11,7 @@
from brainpy import check
from brainpy.check import is_float, is_integer, jit_error
from brainpy.errors import UnsupportedError
-from .compat_numpy import vstack, broadcast_to
+from .compat_numpy import broadcast_to, expand_dims, concatenate
from .environment import get_dt, get_float
from .interoperability import as_jax
from .ndarray import ndarray, Array
@@ -392,6 +392,7 @@ def reset(
dtype=delay_target.dtype),
batch_axis=batch_axis)
else:
self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
dtype=delay_target.dtype)
@@ -472,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None):
elif self.update_method == CONCAT_UPDATE:
if self.num_delay_step >= 2:
- self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]])
+ self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0)
else:
self.data[:] = value
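The concat-mode delay update now rolls the buffer by prepending the newest value and dropping the oldest entry, instead of `vstack`-ing a broadcast copy. The same update written standalone in JAX:

```python
import jax.numpy as jnp

buffer = jnp.zeros((4, 3))      # 4 delay steps for 3 units; index 0 holds the newest value
new_value = jnp.ones(3)
buffer = jnp.concatenate([jnp.expand_dims(new_value, 0), buffer[:-1]], axis=0)
print(buffer[:, 0])             # [1. 0. 0. 0.]
```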
diff --git a/brainpy/_src/math/einops.py b/brainpy/_src/math/einops.py
new file mode 100644
index 000000000..d42026974
--- /dev/null
+++ b/brainpy/_src/math/einops.py
@@ -0,0 +1,728 @@
+import functools
+import itertools
+from collections import OrderedDict
+from typing import Set, Tuple, List, Dict, Union, Callable, Optional, cast
+
+import jax
+import numpy as np
+
+from . import compat_numpy as bnp
+from . import others as bnp2
+from .einops_parsing import ParsedExpression, _ellipsis, AnonymousAxis, EinopsError
+from .ndarray import Array
+
+__all__ = [
+ 'ein_reduce', 'ein_rearrange', 'ein_repeat', 'ein_shape',
+]
+
+Tensor = Union[Array, jax.Array]
+ReductionCallable = Callable[[Tensor, Tuple[int, ...]], Tensor]
+Reduction = Union[str, ReductionCallable]
+
+_reductions = ("min", "max", "sum", "mean", "prod", "any", "all")
+
+# magic integers are required to stay within
+# traceable subset of language
+_unknown_axis_length = -999999
+_expected_axis_length = -99999
+
+
+def _product(sequence: List[int]) -> int:
+ """minimalistic product that works both with numbers and symbols. Supports empty lists"""
+ result = 1
+ for element in sequence:
+ result *= element
+ return result
+
+
+def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int]):
+ if callable(reduction_type):
+ # custom callable
+ return reduction_type(tensor, tuple(reduced_axes))
+ else:
+ # one of built-in operations
+ assert reduction_type in _reductions
+ if reduction_type == "mean":
+ if not bnp2.is_float_type(tensor):
+ raise NotImplementedError("reduce_mean is not available for non-floating tensors")
+ return __reduce(tensor, reduction_type, tuple(reduced_axes))
+
+
+def __reduce(x: Union[Array, jax.Array], operation: str, reduced_axes):
+ if operation == "min":
+ return x.min(axis=reduced_axes)
+ elif operation == "max":
+ return x.max(axis=reduced_axes)
+ elif operation == "sum":
+ return x.sum(axis=reduced_axes)
+ elif operation == "mean":
+ return x.mean(axis=reduced_axes)
+ elif operation == "prod":
+ return x.prod(axis=reduced_axes)
+ elif operation == "any":
+ return x.any(axis=reduced_axes)
+ elif operation == "all":
+ return x.all(axis=reduced_axes)
+ else:
+ raise NotImplementedError("Unknown reduction ", operation)
+
+
+def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
+ # 'collapses' neighboring axes if those participate in the result pattern in the same order
+ # TODO add support for added_axes
+ assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
+ # joining consecutive axes that will be reduced
+ # possibly we can skip this if all backends can optimize this (not sure)
+ reduced_axes = tuple(sorted(reduced_axes))
+ for i in range(len(reduced_axes) - 1)[::-1]:
+ if reduced_axes[i] + 1 == reduced_axes[i + 1]:
+ removed_axis = reduced_axes[i + 1]
+ removed_length = init_shapes[removed_axis]
+ init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
+ init_shapes[removed_axis - 1] *= removed_length
+ reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:])
+
+ # removing axes that are moved together during reshape
+ def build_mapping():
+ init_to_final = {}
+ for axis in range(len(init_shapes)):
+ if axis in reduced_axes:
+ init_to_final[axis] = None
+ else:
+ after_reduction = sum(x is not None for x in init_to_final.values())
+ init_to_final[axis] = list(axes_reordering).index(after_reduction)
+ return init_to_final
+
+ init_axis_to_final_axis = build_mapping()
+
+ for init_axis in range(len(init_shapes) - 1)[::-1]:
+ if init_axis_to_final_axis[init_axis] is None:
+ continue
+ if init_axis_to_final_axis[init_axis + 1] is None:
+ continue
+ if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
+ removed_axis = init_axis + 1
+ removed_length = init_shapes[removed_axis]
+ removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))
+
+ reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
+ init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
+ init_shapes[removed_axis - 1] *= removed_length
+ old_reordering = axes_reordering
+ axes_reordering = []
+ for axis in old_reordering:
+ if axis == removed_axis_after_reduction:
+ pass
+ elif axis < removed_axis_after_reduction:
+ axes_reordering.append(axis)
+ else:
+ axes_reordering.append(axis - 1)
+ init_axis_to_final_axis = build_mapping()
+
+ return init_shapes, reduced_axes, axes_reordering, final_shapes
+
+
+CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int]
+
+# Actual type is tuple[tuple[str, int], ...]
+# However torch.jit.script does not "understand" the correct type,
+# and torch_specific will use list version.
+HashableAxesLengths = Tuple[Tuple[str, int], ...]
+FakeHashableAxesLengths = List[Tuple[str, int]]
+
+
+class TransformRecipe:
+ """
+ Recipe describes actual computation pathway.
+ Recipe can be applied to a tensor or variable.
+ """
+
+ # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+)
+ # update: pytorch 2.0 torch.jit.script seems to have problems with dataclasses unless they were explicitly provided
+
+ def __init__(
+ self,
+ # list of sizes (or just sizes) for elementary axes as they appear in left expression.
+ # this is what (after computing unknown parts) will be a shape after first transposition.
+ # This does not include any ellipsis dimensions.
+ elementary_axes_lengths: List[int],
+ # if additional axes are provided, they should be set in prev array
+ # This shows mapping from name to position
+ axis_name2elementary_axis: Dict[str, int],
+ # each dimension in input can help to reconstruct length of one elementary axis
+ # or verify one of dimensions. Each element points to element of elementary_axes_lengths.
+ input_composition_known_unknown: List[Tuple[List[int], List[int]]],
+ # permutation applied to elementary axes, if ellipsis is absent
+ axes_permutation: List[int],
+ # permutation puts reduced axes in the end, we only need to know the first position.
+ first_reduced_axis: int,
+ # at which positions which of elementary axes should appear. Axis position -> axis index.
+ added_axes: Dict[int, int],
+ # ids of axes as they appear in result, again pointers to elementary_axes_lengths,
+ # only used to infer result dimensions
+ output_composite_axes: List[List[int]],
+ ):
+ self.elementary_axes_lengths: List[int] = elementary_axes_lengths
+ self.axis_name2elementary_axis: Dict[str, int] = axis_name2elementary_axis
+ self.input_composition_known_unknown: List[Tuple[List[int], List[int]]] = input_composition_known_unknown
+ self.axes_permutation: List[int] = axes_permutation
+
+ self.first_reduced_axis: int = first_reduced_axis
+ self.added_axes: Dict[int, int] = added_axes
+ self.output_composite_axes: List[List[int]] = output_composite_axes
+
+
+def _reconstruct_from_shape_uncached(
+ self: TransformRecipe, shape: List[int], axes_dims: FakeHashableAxesLengths
+) -> CookedRecipe:
+ """
+ Reconstruct all actual parameters using shape.
+ Shape is a tuple that may contain integers, shape symbols (tf, theano) and UnknownSize (tf, previously mxnet).
+ Known axes can be integers or symbols, but not Nones.
+ """
+ # magic number
+ need_init_reshape = False
+
+ # last axis is allocated for collapsed ellipsis
+ axes_lengths: List[int] = list(self.elementary_axes_lengths)
+ for axis, dim in axes_dims:
+ axes_lengths[self.axis_name2elementary_axis[axis]] = dim
+
+ for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composition_known_unknown):
+ length = shape[input_axis]
+ if len(known_axes) == 0 and len(unknown_axes) == 1:
+ # shortcut for the most common case
+ axes_lengths[unknown_axes[0]] = length
+ continue
+
+ known_product = 1
+ for axis in known_axes:
+ known_product *= axes_lengths[axis]
+
+ if len(unknown_axes) == 0:
+ if isinstance(length, int) and isinstance(known_product, int) and length != known_product:
+ raise EinopsError(f"Shape mismatch, {length} != {known_product}")
+ else:
+ # assert len(unknown_axes) == 1, 'this is enforced when recipe is created, so commented out'
+ if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0:
+ raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}")
+
+ unknown_axis = unknown_axes[0]
+ inferred_length: int = length // known_product
+ axes_lengths[unknown_axis] = inferred_length
+
+ if len(known_axes) + len(unknown_axes) != 1:
+ need_init_reshape = True
+
+ # at this point all axes_lengths are computed (either have values or variables, but not Nones)
+
+ # elementary axes are ordered as they appear in input, then all added axes
+ init_shapes: Optional[List[int]] = axes_lengths[: len(self.axes_permutation)] if need_init_reshape else None
+
+ need_final_reshape = False
+ final_shapes: List[int] = []
+ for grouping in self.output_composite_axes:
+ lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
+ final_shapes.append(_product(lengths))
+ if len(lengths) != 1:
+ need_final_reshape = True
+
+ added_axes: Dict[int, int] = {
+ pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()
+ }
+
+ # this list can be empty
+ reduced_axes = list(range(self.first_reduced_axis, len(self.axes_permutation)))
+
+ n_axes_after_adding_axes = len(added_axes) + len(self.axes_permutation)
+
+ axes_reordering: Optional[List[int]] = self.axes_permutation
+ if self.axes_permutation == list(range(len(self.axes_permutation))):
+ axes_reordering = None
+
+ _final_shapes = final_shapes if need_final_reshape else None
+ return init_shapes, axes_reordering, reduced_axes, added_axes, _final_shapes, n_axes_after_adding_axes
+
+
+_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached)
+
+
+def _apply_recipe(
+ recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths
+) -> Tensor:
+ # this method implements actual work for all backends for 3 operations
+ try:
+ init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = (
+ _reconstruct_from_shape(recipe, bnp.shape(tensor), axes_lengths))
+ except TypeError:
+ # shape or one of passed axes lengths is not hashable (i.e. they are symbols)
+ _result = _reconstruct_from_shape_uncached(recipe, bnp.shape(tensor), axes_lengths)
+ (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result
+ if init_shapes is not None:
+ tensor = bnp.reshape(bnp.as_jax(tensor), init_shapes)
+ if axes_reordering is not None:
+ tensor = bnp.transpose(bnp.as_jax(tensor), axes_reordering)
+ if len(reduced_axes) > 0:
+ tensor = _reduce_axes(bnp.as_jax(tensor), reduction_type=reduction_type, reduced_axes=reduced_axes)
+ if len(added_axes) > 0:
+ tensor = bnp2.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes)
+ if final_shapes is not None:
+ tensor = bnp.reshape(bnp.as_jax(tensor), final_shapes)
+ return tensor
+
+
+def _apply_recipe_array_api(
+ xp, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths
+) -> Tensor:
+ # completely-inline implementation
+ init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
+ recipe, tensor.shape, axes_lengths
+ )
+ if init_shapes is not None:
+ tensor = xp.reshape(tensor, init_shapes)
+ if axes_reordering is not None:
+ tensor = xp.permute_dims(tensor, axes_reordering)
+ if len(reduced_axes) > 0:
+ if callable(reduction_type):
+ # custom callable
+ tensor = reduction_type(tensor, tuple(reduced_axes))
+ else:
+ # one of built-in operations
+ assert reduction_type in _reductions
+ tensor = getattr(xp, reduction_type)(tensor, axis=tuple(reduced_axes))
+ if len(added_axes) > 0:
+ # we use broadcasting
+ for axis_position, axis_length in added_axes.items():
+ tensor = xp.expand_dims(tensor, axis=axis_position)
+
+ final_shape = list(tensor.shape)
+ for axis_position, axis_length in added_axes.items():
+ final_shape[axis_position] = axis_length
+
+ tensor = xp.broadcast_to(tensor, final_shape)
+ if final_shapes is not None:
+ tensor = xp.reshape(tensor, final_shapes)
+ return tensor
+
+
+@functools.lru_cache(256)
+def _prepare_transformation_recipe(
+ pattern: str,
+ operation: Reduction,
+ axes_names: Tuple[str, ...],
+ ndim: int,
+) -> TransformRecipe:
+ """Perform initial parsing of pattern and provided supplementary info
+ axes_lengths is a tuple of tuples (axis_name, axis_length)
+ """
+ left_str, rght_str = pattern.split("->")
+ left = ParsedExpression(left_str)
+ rght = ParsedExpression(rght_str)
+
+ # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction
+ if not left.has_ellipsis and rght.has_ellipsis:
+ raise EinopsError("Ellipsis found in right side, but not left side of a pattern {}".format(pattern))
+ if left.has_ellipsis and left.has_ellipsis_parenthesized:
+ raise EinopsError("Ellipsis inside parenthesis in the left side is not allowed: {}".format(pattern))
+ if operation == "rearrange":
+ if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes:
+ raise EinopsError("Non-unitary anonymous axes are not supported in rearrange (exception is length 1)")
+ difference = set.symmetric_difference(left.identifiers, rght.identifiers)
+ if len(difference) > 0:
+ raise EinopsError("Identifiers only on one side of expression (should be on both): {}".format(difference))
+ elif operation == "repeat":
+ difference = set.difference(left.identifiers, rght.identifiers)
+ if len(difference) > 0:
+ raise EinopsError("Unexpected identifiers on the left side of repeat: {}".format(difference))
+ axes_without_size = set.difference(
+ {ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)},
+ {*left.identifiers, *axes_names},
+ )
+ if len(axes_without_size) > 0:
+ raise EinopsError("Specify sizes for new axes in repeat: {}".format(axes_without_size))
+ elif operation in _reductions or callable(operation):
+ difference = set.difference(rght.identifiers, left.identifiers)
+ if len(difference) > 0:
+ raise EinopsError("Unexpected identifiers on the right side of reduce {}: {}".format(operation, difference))
+ else:
+ raise EinopsError("Unknown reduction {}. Expect one of {}.".format(operation, _reductions))
+
+ if left.has_ellipsis:
+ n_other_dims = len(left.composition) - 1
+ if ndim < n_other_dims:
+ raise EinopsError(f"Wrong shape: expected >={n_other_dims} dims. Received {ndim}-dim tensor.")
+ ellipsis_ndim = ndim - n_other_dims
+ ell_axes = [_ellipsis + str(i) for i in range(ellipsis_ndim)]
+ left_composition = []
+ for composite_axis in left.composition:
+ if composite_axis == _ellipsis:
+ for axis in ell_axes:
+ left_composition.append([axis])
+ else:
+ left_composition.append(composite_axis)
+
+ rght_composition = []
+ for composite_axis in rght.composition:
+ if composite_axis == _ellipsis:
+ for axis in ell_axes:
+ rght_composition.append([axis])
+ else:
+ group = []
+ for axis in composite_axis:
+ if axis == _ellipsis:
+ group.extend(ell_axes)
+ else:
+ group.append(axis)
+ rght_composition.append(group)
+
+ left.identifiers.update(ell_axes)
+ left.identifiers.remove(_ellipsis)
+ if rght.has_ellipsis:
+ rght.identifiers.update(ell_axes)
+ rght.identifiers.remove(_ellipsis)
+ else:
+ if ndim != len(left.composition):
+ raise EinopsError(f"Wrong shape: expected {len(left.composition)} dims. Received {ndim}-dim tensor.")
+ left_composition = left.composition
+ rght_composition = rght.composition
+
+ # parsing all dimensions to find out lengths
+ axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict()
+ for composite_axis in left_composition:
+ for axis_name in composite_axis:
+ if isinstance(axis_name, AnonymousAxis):
+ axis_name2known_length[axis_name] = axis_name.value
+ else:
+ axis_name2known_length[axis_name] = _unknown_axis_length
+
+ # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point
+
+ repeat_axes_names = []
+ for axis_name in rght.identifiers:
+ if axis_name not in axis_name2known_length:
+ if isinstance(axis_name, AnonymousAxis):
+ axis_name2known_length[axis_name] = axis_name.value
+ else:
+ axis_name2known_length[axis_name] = _unknown_axis_length
+ repeat_axes_names.append(axis_name)
+
+ axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}
+
+ # axes provided as kwargs
+ for elementary_axis in axes_names:
+ if not ParsedExpression.check_axis_name(elementary_axis):
+ raise EinopsError("Invalid name for an axis", elementary_axis)
+ if elementary_axis not in axis_name2known_length:
+ raise EinopsError("Axis {} is not used in transform".format(elementary_axis))
+ axis_name2known_length[elementary_axis] = _expected_axis_length
+
+ input_axes_known_unknown = []
+ # some shapes are inferred later - all information is prepared for faster inference
+ for i, composite_axis in enumerate(left_composition):
+ known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
+ unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
+ if len(unknown) > 1:
+ raise EinopsError("Could not infer sizes for {}".format(unknown))
+ assert len(unknown) + len(known) == len(composite_axis)
+ input_axes_known_unknown.append(
+ ([axis_name2position[axis] for axis in known], [axis_name2position[axis] for axis in unknown])
+ )
+
+ axis_position_after_reduction: Dict[str, int] = {}
+ for axis_name in itertools.chain(*left_composition):
+ if axis_name in rght.identifiers:
+ axis_position_after_reduction[axis_name] = len(axis_position_after_reduction)
+
+ result_axes_grouping: List[List[int]] = [
+ [axis_name2position[axis] for axis in composite_axis] for i, composite_axis in enumerate(rght_composition)
+ ]
+
+ ordered_axis_left = list(itertools.chain(*left_composition))
+ ordered_axis_rght = list(itertools.chain(*rght_composition))
+ reduced_axes = [axis for axis in ordered_axis_left if axis not in rght.identifiers]
+ order_after_transposition = [axis for axis in ordered_axis_rght if axis in left.identifiers] + reduced_axes
+ axes_permutation = [ordered_axis_left.index(axis) for axis in order_after_transposition]
+ added_axes = {
+ i: axis_name2position[axis_name]
+ for i, axis_name in enumerate(ordered_axis_rght)
+ if axis_name not in left.identifiers
+ }
+
+ first_reduced_axis = len(order_after_transposition) - len(reduced_axes)
+
+ return TransformRecipe(
+ elementary_axes_lengths=list(axis_name2known_length.values()),
+ axis_name2elementary_axis={axis: axis_name2position[axis] for axis in axes_names},
+ input_composition_known_unknown=input_axes_known_unknown,
+ axes_permutation=axes_permutation,
+ first_reduced_axis=first_reduced_axis,
+ added_axes=added_axes,
+ output_composite_axes=result_axes_grouping,
+ )
+
+
+def _prepare_recipes_for_all_dims(
+ pattern: str, operation: Reduction, axes_names: Tuple[str, ...]
+) -> Dict[int, TransformRecipe]:
+ """
+ Internal function, used in layers.
+ A layer creates all its recipes when it is initialized, so to keep recipes simple we pre-compute them for all dims.
+ """
+ left_str, rght_str = pattern.split("->")
+ left = ParsedExpression(left_str)
+ dims = [len(left.composition)]
+ if left.has_ellipsis:
+ dims = [len(left.composition) - 1 + ellipsis_dims for ellipsis_dims in range(8)]
+ return {ndim: _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim) for ndim in dims}
+
+
+def ein_reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
+ """
+ ``ein_reduce`` provides combination of reordering and reduction using reader-friendly notation.
+
+ Examples for reduce operation:
+
+ ```python
+ >>> x = np.random.randn(100, 32, 64)
+
+ # perform max-reduction on the first axis
+ >>> y = ein_reduce(x, 't b c -> b c', 'max')
+
+ # same as previous, but with clearer axes meaning
+ >>> y = ein_reduce(x, 'time batch channel -> batch channel', 'max')
+
+ >>> x = np.random.randn(10, 20, 30, 40)
+
+ # 2d max-pooling with kernel size = 2 * 2 for image processing
+ >>> y1 = ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
+
+ # if one wants to go back to the original height and width, depth-to-space trick can be applied
+ >>> y2 = ein_rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
+ >>> assert ein_shape(x, 'b _ h w') == ein_shape(y2, 'b _ h w')
+
+ # Adaptive 2d max-pooling to 3 * 4 grid
+ >>> ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
+ (10, 20, 3, 4)
+
+ # Global average pooling
+ >>> ein_reduce(x, 'b c h w -> b c', 'mean').shape
+ (10, 20)
+
+ # Subtracting mean over batch for each channel
+ >>> y = x - ein_reduce(x, 'b c h w -> () c () ()', 'mean')
+
+ # Subtracting per-image mean for each channel
+ >>> y = x - ein_reduce(x, 'b c h w -> b c () ()', 'mean')
+
+ ```
+
+ Parameters:
+ tensor: tensor of any supported type (e.g. numpy.ndarray, jax.Array, brainpy Array).
+ A list of tensors is also accepted; those should be of the same type and shape.
+ pattern: string, reduction pattern
+ reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod', 'any', 'all'), case-sensitive
+ alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
+ This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc.
+ axes_lengths: any additional specifications for dimensions
+
+ Returns:
+ tensor of the same type as input
+ """
+ try:
+ hashable_axes_lengths = tuple(axes_lengths.items())
+ shape = bnp.shape(tensor)
+ recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape))
+ return _apply_recipe(recipe,
+ cast(Tensor, tensor),
+ reduction_type=reduction,
+ axes_lengths=hashable_axes_lengths)
+ except EinopsError as e:
+ message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
+ if not isinstance(tensor, list):
+ message += "\n Input tensor shape: {}. ".format(shape)
+ else:
+ message += "\n Input is list. "
+ message += "Additional info: {}.".format(axes_lengths)
+ raise EinopsError(message + "\n {}".format(e))
+
+
+def ein_rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor:
+ """
+ ``ein_rearrange`` is a reader-friendly smart element reordering for multidimensional tensors.
+ This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
+ stack, concatenate and other operations.
+
+ Examples for rearrange operation:
+
+ ```python
+ # suppose we have a set of 32 images in "h w c" format (height-width-channel)
+ >>> images = [np.random.randn(30, 40, 3) for _ in range(32)]
+
+ # stack along first (batch) axis, output is a single array
+ >>> ein_rearrange(images, 'b h w c -> b h w c').shape
+ (32, 30, 40, 3)
+
+ # concatenate images along height (vertical axis), 960 = 32 * 30
+ >>> ein_rearrange(images, 'b h w c -> (b h) w c').shape
+ (960, 40, 3)
+
+ # concatenated images along horizontal axis, 1280 = 32 * 40
+ >>> ein_rearrange(images, 'b h w c -> h (b w) c').shape
+ (30, 1280, 3)
+
+ # reordered axes to "b c h w" format for deep learning
+ >>> ein_rearrange(images, 'b h w c -> b c h w').shape
+ (32, 3, 30, 40)
+
+ # flattened each image into a vector, 3600 = 30 * 40 * 3
+ >>> ein_rearrange(images, 'b h w c -> b (c h w)').shape
+ (32, 3600)
+
+ # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
+ >>> ein_rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
+ (128, 15, 20, 3)
+
+ # space-to-depth operation
+ >>> ein_rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
+ (32, 15, 20, 12)
+
+ ```
+
+ When composing axes, C-order enumeration is used (consecutive elements have different last axis).
+ Find more examples in the einops tutorial.
+
+ Parameters:
+ tensor: tensor of any supported type (e.g. numpy.ndarray, jax.Array, brainpy Array).
+ A list of tensors is also accepted; those should be of the same type and shape.
+ pattern: string, rearrangement pattern
+ axes_lengths: any additional specifications for dimensions
+
+ Returns:
+ tensor of the same type as input. If possible, a view to the original tensor is returned.
+
+ """
+ return ein_reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
+
+
+def ein_repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor:
+ """
+ ``ein_repeat`` allows reordering elements and repeating them in arbitrary combinations.
+ This operation includes functionality of repeat, tile, broadcast functions.
+
+ Examples for repeat operation:
+
+ ```python
+ # a grayscale image (of shape height x width)
+ >>> image = np.random.randn(30, 40)
+
+ # change it to RGB format by repeating in each channel
+ >>> ein_repeat(image, 'h w -> h w c', c=3).shape
+ (30, 40, 3)
+
+ # repeat image 2 times along height (vertical axis)
+ >>> ein_repeat(image, 'h w -> (repeat h) w', repeat=2).shape
+ (60, 40)
+
+ # repeat image 2 times along height and 3 times along width
+ >>> ein_repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape
+ (60, 120)
+
+ # convert each pixel to a small square 2x2. Upsample image by 2x
+ >>> ein_repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
+ (60, 80)
+
+ # pixelate image first by downsampling by 2x, then upsampling
+ >>> downsampled = ein_reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
+ >>> ein_repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
+ (30, 40)
+
+ ```
+
+ When composing axes, C-order enumeration is used (consecutive elements have different last axis).
+ Find more examples in the einops tutorial.
+
+ Parameters:
+ tensor: tensor of any supported type (e.g. numpy.ndarray, jax.Array, brainpy Array).
+ A list of tensors is also accepted; those should be of the same type and shape.
+ pattern: string, rearrangement pattern
+ axes_lengths: any additional specifications for dimensions
+
+ Returns:
+ Tensor of the same type as input. If possible, a view to the original tensor is returned.
+
+ """
+ return ein_reduce(tensor, pattern, reduction="repeat", **axes_lengths)
+
+
+def ein_shape(x, pattern: str) -> dict:
+ """
+ Parse a tensor shape to dictionary mapping axes names to their lengths.
+
+ ```python
+ # Use underscore to skip the dimension in parsing.
+ >>> x = np.zeros([2, 3, 5, 7])
+ >>> ein_shape(x, 'batch _ h w')
+ {'batch': 2, 'h': 5, 'w': 7}
+
+ # `ein_shape` output can be used to specify axes_lengths for other operations:
+ >>> y = np.zeros([700])
+ >>> ein_rearrange(y, '(b c h w) -> b c h w', **ein_shape(x, 'b _ h w')).shape
+ (2, 10, 5, 7)
+
+ ```
+
+ For symbolic frameworks this may return symbols, not integers.
+
+ Parameters:
+ x: tensor of any supported framework
+ pattern: str, space separated names for axes, underscore means skip axis
+
+ Returns:
+ dict, maps axes names to their lengths
+ """
+ exp = ParsedExpression(pattern, allow_underscore=True)
+ shape = bnp.shape(x)
+ if exp.has_composed_axes():
+ raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}")
+ if len(shape) != len(exp.composition):
+ if exp.has_ellipsis:
+ if len(shape) < len(exp.composition) - 1:
+ raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}")
+ else:
+ raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}")
+ if exp.has_ellipsis:
+ ellipsis_idx = exp.composition.index(_ellipsis)
+ composition = (
+ exp.composition[:ellipsis_idx]
+ + ["_"] * (len(shape) - len(exp.composition) + 1)
+ + exp.composition[ellipsis_idx + 1:]
+ )
+ else:
+ composition = exp.composition
+ result = {}
+ for (axis_name,), axis_length in zip(composition, shape): # type: ignore
+ if axis_name != "_":
+ result[axis_name] = axis_length
+ return result
+
+
+# _enumerate_directions is not exposed in the public API
+def _enumerate_directions(x):
+ """
+ For an n-dimensional tensor, returns tensors to enumerate each axis.
+ ```python
+ x = np.zeros([2, 3, 4]) # or any other tensor
+ i, j, k = _enumerate_directions(x)
+ result = i + 2*j + 3*k
+ ```
+
+ `result[i, j, k] = i + 2*j + 3*k`, and it also has the same shape as ``x``.
+ Works very similarly to numpy.ogrid (open indexing grid)
+ """
+ shape = bnp.shape(x)
+ result = []
+ for axis_id, axis_length in enumerate(shape):
+ shape = [1] * len(shape)
+ shape[axis_id] = axis_length
+ result.append(bnp.reshape(bnp.arange(0, axis_length), shape))
+ return result
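End-to-end use of the new einops-style helpers; this sketch assumes they are re-exported under `brainpy.math`:

```python
import brainpy.math as bm

x = bm.random.randn(8, 3, 32, 32)                         # batch, channels, height, width
flat = bm.ein_rearrange(x, 'b c h w -> b (c h w)')        # flatten each sample
pooled = bm.ein_reduce(x, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=2, w2=2)
print(bm.ein_shape(x, 'b c h w'))                         # {'b': 8, 'c': 3, 'h': 32, 'w': 32}
print(flat.shape, pooled.shape)                           # (8, 3072) (8, 3, 16, 16)
```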
diff --git a/brainpy/_src/math/einops_parsing.py b/brainpy/_src/math/einops_parsing.py
new file mode 100644
index 000000000..6ce055bdb
--- /dev/null
+++ b/brainpy/_src/math/einops_parsing.py
@@ -0,0 +1,153 @@
+import keyword
+import warnings
+from typing import List, Optional, Set, Tuple, Union
+
+_ellipsis: str = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
+
+
+class EinopsError(Exception):
+ pass
+
+
+class AnonymousAxis(object):
+ """Important thing: all instances of this class are not equal to each other """
+
+ def __init__(self, value: str):
+ self.value = int(value)
+ if self.value <= 1:
+ if self.value == 1:
+ raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
+ else:
+ raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value))
+
+ def __repr__(self):
+ return "{}-axis".format(str(self.value))
+
+
+class ParsedExpression:
+ """
+ non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
+ and keeps some information important for downstream
+ """
+
+ def __init__(self, expression: str, *, allow_underscore: bool = False,
+ allow_duplicates: bool = False):
+ self.has_ellipsis: bool = False
+ self.has_ellipsis_parenthesized: Optional[bool] = None
+ self.identifiers: Set[str] = set()
+ # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
+ self.has_non_unitary_anonymous_axes: bool = False
+ # composition keeps structure of composite axes, see how different corner cases are handled in tests
+ self.composition: List[Union[List[str], str]] = []
+ if '.' in expression:
+ if '...' not in expression:
+ raise EinopsError('Expression may contain dots only inside ellipsis (...)')
+ if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
+ raise EinopsError(
+ 'Expression may contain dots only inside ellipsis (...), and only one ellipsis is allowed per tensor')
+ expression = expression.replace('...', _ellipsis)
+ self.has_ellipsis = True
+
+ bracket_group: Optional[List[str]] = None
+
+ def add_axis_name(x):
+ if x in self.identifiers:
+ if not (allow_underscore and x == "_") and not allow_duplicates:
+ raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
+ if x == _ellipsis:
+ self.identifiers.add(_ellipsis)
+ if bracket_group is None:
+ self.composition.append(_ellipsis)
+ self.has_ellipsis_parenthesized = False
+ else:
+ bracket_group.append(_ellipsis)
+ self.has_ellipsis_parenthesized = True
+ else:
+ is_number = str.isdecimal(x)
+ if is_number and int(x) == 1:
+ # handling the case of anonymous axis of length 1
+ if bracket_group is None:
+ self.composition.append([])
+ else:
+ pass # no need to think about 1s inside parenthesis
+ return
+ is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
+ if not (is_number or is_axis_name):
+ raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
+ if is_number:
+ x = AnonymousAxis(x)
+ self.identifiers.add(x)
+ if is_number:
+ self.has_non_unitary_anonymous_axes = True
+ if bracket_group is None:
+ self.composition.append([x])
+ else:
+ bracket_group.append(x)
+
+ current_identifier = None
+ for char in expression:
+ if char in '() ':
+ if current_identifier is not None:
+ add_axis_name(current_identifier)
+ current_identifier = None
+ if char == '(':
+ if bracket_group is not None:
+ raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
+ bracket_group = []
+ elif char == ')':
+ if bracket_group is None:
+ raise EinopsError('Brackets are not balanced')
+ self.composition.append(bracket_group)
+ bracket_group = None
+ elif str.isalnum(char) or char in ['_', _ellipsis]:
+ if current_identifier is None:
+ current_identifier = char
+ else:
+ current_identifier += char
+ else:
+ raise EinopsError("Unknown character '{}'".format(char))
+
+ if bracket_group is not None:
+ raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
+ if current_identifier is not None:
+ add_axis_name(current_identifier)
+
+ def flat_axes_order(self) -> List:
+ result = []
+ for composed_axis in self.composition:
+ assert isinstance(composed_axis, list), 'does not work with ellipsis'
+ for axis in composed_axis:
+ result.append(axis)
+ return result
+
+ def has_composed_axes(self) -> bool:
+ # this will ignore 1 inside brackets
+ for axes in self.composition:
+ if isinstance(axes, list) and len(axes) > 1:
+ return True
+ return False
+
+ @staticmethod
+ def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
+ if not str.isidentifier(name):
+ return False, 'not a valid python identifier'
+ elif name[0] == '_' or name[-1] == '_':
+ if name == '_' and allow_underscore:
+ return True, ''
+      return False, 'axis name should not start or end with underscore'
+ else:
+ if keyword.iskeyword(name):
+ warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
+ if name in ['axis']:
+ warnings.warn("It is discouraged to use 'axis' as an axis name "
+ "and will raise an error in future", FutureWarning)
+ return True, ''
+
+ @staticmethod
+ def check_axis_name(name: str) -> bool:
+ """
+    Valid axis names are Python identifiers that do not start or end with an
+    underscore. Keyword names are currently accepted but trigger a warning.
+ """
+ is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
+ return is_valid
diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py
index eef0361fc..1c8b98a3b 100644
--- a/brainpy/_src/math/environment.py
+++ b/brainpy/_src/math/environment.py
@@ -9,13 +9,16 @@
import warnings
from typing import Any, Callable, TypeVar, cast
+import jax
from jax import config, numpy as jnp, devices
from jax.lib import xla_bridge
from . import modes
from . import scales
+from . import defaults
+from brainpy._src.dependency_check import import_taichi
-bm = None
+ti = import_taichi()
__all__ = [
# context manage for environment setting
@@ -388,9 +391,7 @@ def ditype():
"""
# raise errors.NoLongerSupportError('\nGet default integer data type through `ditype()` has been deprecated. \n'
# 'Use `brainpy.math.int_` instead.')
- global bm
- if bm is None: from brainpy import math as bm
- return bm.int_
+ return defaults.int_
def dftype():
@@ -402,9 +403,7 @@ def dftype():
# raise errors.NoLongerSupportError('\nGet default floating data type through `dftype()` has been deprecated. \n'
# 'Use `brainpy.math.float_` instead.')
- global bm
- if bm is None: from brainpy import math as bm
- return bm.float_
+ return defaults.float_
def set_float(dtype: type):
@@ -415,11 +414,17 @@ def set_float(dtype: type):
dtype: type
The float type.
"""
- if dtype not in [jnp.float16, jnp.float32, jnp.float64, ]:
- raise TypeError(f'Float data type {dtype} is not supported.')
- global bm
- if bm is None: from brainpy import math as bm
- bm.__dict__['float_'] = dtype
+ if dtype in [jnp.float16, 'float16', 'f16']:
+ defaults.__dict__['float_'] = jnp.float16
+ defaults.__dict__['ti_float'] = ti.float16
+ elif dtype in [jnp.float32, 'float32', 'f32']:
+ defaults.__dict__['float_'] = jnp.float32
+ defaults.__dict__['ti_float'] = ti.float32
+ elif dtype in [jnp.float64, 'float64', 'f64']:
+ defaults.__dict__['float_'] = jnp.float64
+ defaults.__dict__['ti_float'] = ti.float64
+ else:
+ raise NotImplementedError
def get_float():
@@ -430,9 +435,7 @@ def get_float():
dftype: type
The default float data type.
"""
- global bm
- if bm is None: from brainpy import math as bm
- return bm.float_
+ return defaults.float_
def set_int(dtype: type):
@@ -443,12 +446,20 @@ def set_int(dtype: type):
dtype: type
The integer type.
"""
- if dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
- jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, ]:
- raise TypeError(f'Integer data type {dtype} is not supported.')
- global bm
- if bm is None: from brainpy import math as bm
- bm.__dict__['int_'] = dtype
+ if dtype in [jnp.int8, 'int8', 'i8']:
+ defaults.__dict__['int_'] = jnp.int8
+ defaults.__dict__['ti_int'] = ti.int8
+ elif dtype in [jnp.int16, 'int16', 'i16']:
+ defaults.__dict__['int_'] = jnp.int16
+ defaults.__dict__['ti_int'] = ti.int16
+ elif dtype in [jnp.int32, 'int32', 'i32']:
+ defaults.__dict__['int_'] = jnp.int32
+ defaults.__dict__['ti_int'] = ti.int32
+ elif dtype in [jnp.int64, 'int64', 'i64']:
+ defaults.__dict__['int_'] = jnp.int64
+ defaults.__dict__['ti_int'] = ti.int64
+ else:
+ raise NotImplementedError
def get_int():
@@ -459,9 +470,7 @@ def get_int():
dftype: type
The default int data type.
"""
- global bm
- if bm is None: from brainpy import math as bm
- return bm.int_
+ return defaults.int_
def set_bool(dtype: type):
@@ -472,9 +481,7 @@ def set_bool(dtype: type):
dtype: type
The bool type.
"""
- global bm
- if bm is None: from brainpy import math as bm
- bm.__dict__['bool_'] = dtype
+ defaults.__dict__['bool_'] = dtype
def get_bool():
@@ -485,9 +492,7 @@ def get_bool():
dftype: type
The default bool data type.
"""
- global bm
- if bm is None: from brainpy import math as bm
- return bm.bool_
+ return defaults.bool_
def set_complex(dtype: type):
@@ -498,9 +503,7 @@ def set_complex(dtype: type):
dtype: type
The complex type.
"""
- global bm
- if bm is None: from brainpy import math as bm
- bm.__dict__['complex_'] = dtype
+ defaults.__dict__['complex_'] = dtype
def get_complex():
@@ -511,9 +514,7 @@ def get_complex():
dftype: type
The default complex data type.
"""
- global bm
- if bm is None: from brainpy import math as bm
- return bm.complex_
+ return defaults.complex_
# numerical precision
@@ -528,9 +529,7 @@ def set_dt(dt):
Numerical integration precision.
"""
  assert isinstance(dt, float), f'"dt" must be a float, but we got {dt}'
- global bm
- if bm is None: from brainpy import math as bm
- bm.__dict__['dt'] = dt
+ defaults.__dict__['dt'] = dt
def get_dt():
@@ -541,9 +540,7 @@ def get_dt():
dt : float
Numerical integration precision.
"""
- global bm
- if bm is None: from brainpy import math as bm
- return bm.dt
+ return defaults.dt
def set_mode(mode: modes.Mode):
@@ -557,9 +554,7 @@ def set_mode(mode: modes.Mode):
if not isinstance(mode, modes.Mode):
raise TypeError(f'Must be instance of brainpy.math.Mode. '
f'But we got {type(mode)}: {mode}')
- global bm
- if bm is None: from brainpy import math as bm
- bm.__dict__['mode'] = mode
+ defaults.__dict__['mode'] = mode
def get_mode() -> modes.Mode:
@@ -570,9 +565,7 @@ def get_mode() -> modes.Mode:
mode: Mode
The default computing mode.
"""
- global bm
- if bm is None: from brainpy import math as bm
- return bm.mode
+ return defaults.mode
def set_membrane_scaling(membrane_scaling: scales.Scaling):
@@ -586,9 +579,7 @@ def set_membrane_scaling(membrane_scaling: scales.Scaling):
if not isinstance(membrane_scaling, scales.Scaling):
raise TypeError(f'Must be instance of brainpy.math.Scaling. '
f'But we got {type(membrane_scaling)}: {membrane_scaling}')
- global bm
- if bm is None: from brainpy import math as bm
- bm.__dict__['membrane_scaling'] = membrane_scaling
+ defaults.__dict__['membrane_scaling'] = membrane_scaling
def get_membrane_scaling() -> scales.Scaling:
@@ -599,9 +590,7 @@ def get_membrane_scaling() -> scales.Scaling:
membrane_scaling: Scaling
The default computing membrane_scaling.
"""
- global bm
- if bm is None: from brainpy import math as bm
- return bm.membrane_scaling
+ return defaults.membrane_scaling
def enable_x64(x64=None):
@@ -682,7 +671,11 @@ def set_host_device_count(n):
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
-def clear_buffer_memory(platform=None):
+def clear_buffer_memory(
+ platform: str = None,
+ array: bool = True,
+ compilation: bool = False
+):
"""Clear all on-device buffers.
This function will be very useful when you call models in a Python loop,
@@ -697,18 +690,47 @@ def clear_buffer_memory(platform=None):
----------
platform: str
The device to clear its memory.
+  array: bool
+    Whether to clear all on-device buffer arrays. Default is ``True``.
+  compilation: bool
+    Whether to clear the JAX compilation cache. Default is ``False``.
+
"""
- for buf in xla_bridge.get_backend(platform=platform).live_buffers():
- buf.delete()
+ if array:
+ for buf in xla_bridge.get_backend(platform).live_buffers():
+ buf.delete()
+ if compilation:
+ jax.clear_caches()
-def disable_gpu_memory_preallocation():
- """Disable pre-allocating the GPU memory."""
+def disable_gpu_memory_preallocation(release_memory: bool = True):
+ """Disable pre-allocating the GPU memory.
+
+ This disables the preallocation behavior. JAX will instead allocate GPU memory as needed,
+ potentially decreasing the overall memory usage. However, this behavior is more prone to
+ GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory
+ may OOM with preallocation disabled.
+
+ Args:
+ release_memory: bool. Whether we release memory during the computation.
+ """
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
- os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
+ if release_memory:
+ os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
def enable_gpu_memory_preallocation():
"""Disable pre-allocating the GPU memory."""
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
- os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR')
+ os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None)
+
+
+def gpu_memory_preallocation(percent: float):
+ """GPU memory allocation.
+
+ If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory,
+ instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.
+ """
+  assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.). But we got {percent}.'
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent)
+
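+# Illustrative usage of the helpers above (a sketch; it assumes these functions
+# are re-exported under ``brainpy.math``, like the pre-existing ones):
+#
+#   import brainpy.math as bm
+#   bm.set_float('float32')                 # also accepts jnp.float32 / 'f32'
+#   bm.set_int('int32')                     # also accepts jnp.int32 / 'i32'
+#   bm.disable_gpu_memory_preallocation()   # allocate GPU memory on demand
+#   bm.gpu_memory_preallocation(0.5)        # or preallocate 50% of total GPU memory
+#   bm.clear_buffer_memory('gpu', array=True, compilation=True)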
diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py
index 9da0cf524..6e03be463 100644
--- a/brainpy/_src/math/event/_csr_matvec.py
+++ b/brainpy/_src/math/event/_csr_matvec.py
@@ -10,7 +10,6 @@
"""
-
from functools import partial
from typing import Union, Tuple
@@ -22,20 +21,69 @@
from jax.interpreters import ad, xla
from jax.lib import xla_client
+from brainpy._src.dependency_check import (import_brainpylib_gpu_ops)
+from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import (compile_cpu_signature_with_numba,
- register_general_batching)
-from brainpy._src.math.sparse._csr_mv import csrmv as normal_csrmv
+ register_general_batching,
+ XLACustomOp)
+from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv
+from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi
from brainpy._src.math.sparse._utils import csr_to_coo
-from brainpy._src.dependency_check import (import_brainpylib_gpu_ops)
from brainpy.errors import GPUOperatorNotFound
__all__ = [
'csrmv'
]
+ti = import_taichi()
+
def csrmv(
+ data: Union[float, jax.Array],
+ indices: jax.Array,
+ indptr: jax.Array,
+ events: jax.Array,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+) -> jax.Array:
+ """Product of a sparse CSR matrix and a dense event vector.
+
+ This function supports JAX transformations, including `jit()`, `grad()`,
+ `vmap()` and `pmap()`.
+
+ Parameters
+ ----------
+ data: ndarray, float
+ An array of shape ``(nse,)``.
+ indices: ndarray
+ An array of shape ``(nse,)``.
+ indptr: ndarray
+ An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+ events: ndarray
+ An array of shape ``(shape[0] if transpose else shape[1],)``
+ and dtype ``data.dtype``.
+ shape: tuple
+ A length-2 tuple representing the matrix shape.
+ transpose: bool
+ A boolean specifying whether to transpose the sparse matrix
+ before computing.
+ If ``transpose=True``, the operator will compute based on the
+ event-driven property of the ``events`` vector.
+
+ Returns
+ -------
+ y : Array
+ The array of shape ``(shape[1] if transpose else shape[0],)`` representing
+ the matrix vector product.
+ """
+ return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)
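+
+# Illustrative usage (a sketch; it assumes the wrapper above is re-exported as
+# ``brainpy.math.event.csrmv``, matching ``__all__``):
+#
+#   import jax.numpy as jnp
+#   import brainpy.math as bm
+#   data = jnp.ones(3)                       # 2x3 CSR matrix with 3 non-zeros
+#   indices = jnp.array([0, 2, 1])
+#   indptr = jnp.array([0, 2, 3])
+#   events = jnp.array([True, False, True])  # dense event vector (len = shape[1])
+#   y = bm.event.csrmv(data, indices, indptr, events, shape=(2, 3))
+#   # expected: y == [2., 0.] since row 0 sees active columns 0 and 2,
+#   # and row 1 sees only the inactive column 1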
+
+
+### BRAINPYLIB ###
+
+def csrmv_brainpylib(
data: Union[float, jax.Array],
indices: jax.Array,
indptr: jax.Array,
@@ -304,7 +352,7 @@ def _f(ct, indices, indptr, events, *, transpose):
event_csr_matvec_batching_p = Primitive('event_csr_matvec_batching')
event_csr_matvec_batching_p.def_abstract_eval(_batch_event_csr_matvec_abstract)
event_csr_matvec_batching_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_batching_p))
-xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation
+# xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation
ad.defjvp(event_csr_matvec_batching_p, _batch_event_csr_matvec_jvp_values,
None, None, _batch_event_csr_matvec_jvp_events)
ad.primitive_transposes[event_csr_matvec_batching_p] = _batch_event_csr_matvec_transpose
@@ -519,15 +567,15 @@ def _event_csr_matvec_batching_rule(args, axes, *, shape, transpose):
return r, 0
-def _event_csr_matvec_jvp_values(values_dot, values, indices, indptr, events, *, shape, transpose):
- return csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose)
+def _event_csr_matvec_jvp_values_brainpylib(values_dot, values, indices, indptr, events, *, shape, transpose):
+ return normal_csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose)
-def _event_csr_matvec_jvp_events(events_dot, values, indices, indptr, events, *, shape, transpose):
+def _event_csr_matvec_jvp_events_brainpylib(events_dot, values, indices, indptr, events, *, shape, transpose):
return normal_csrmv(values, indices, indptr, events_dot, shape=shape, transpose=transpose)
-def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, transpose):
+def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events, *, shape, transpose):
if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
raise ValueError("Cannot transpose with respect to sparse indices.")
if ad.is_undefined_primal(events):
@@ -538,7 +586,7 @@ def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, t
ct_values = ad.Zero(values)
else:
if values.aval.shape[0] == 1: # scalar
- ct_values = csrmv(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)
+ ct_values = csrmv_brainpylib(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)
ct_values = jnp.inner(ct, ct_values)
else: # heterogeneous values
row, col = csr_to_coo(indices, indptr)
@@ -549,9 +597,493 @@ def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, t
event_csr_matvec_p = Primitive('event_csr_matvec')
event_csr_matvec_p.def_abstract_eval(_event_csr_matvec_abstract)
event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p))
-xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
-xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
-ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values, None, None, _event_csr_matvec_jvp_events)
-ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose
+# xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
+# xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
+ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None,
+ _event_csr_matvec_jvp_events_brainpylib)
+ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib
register_general_batching(event_csr_matvec_p)
+
+
# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule
+
+
+### TAICHI ###
+
+def csrmv_taichi(
+ data: Union[float, jax.Array],
+ indices: jax.Array,
+ indptr: jax.Array,
+ events: jax.Array,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False
+) -> jax.Array:
+ """Product of a sparse CSR matrix and a dense event vector.
+
+ This function supports JAX transformations, including `jit()`, `grad()`,
+ `vmap()` and `pmap()`.
+
+ Parameters
+ ----------
+ data: ndarray, float
+ An array of shape ``(nse,)``.
+ indices: ndarray
+ An array of shape ``(nse,)``.
+ indptr: ndarray
+ An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+ events: ndarray
+ An array of shape ``(shape[0] if transpose else shape[1],)``
+ and dtype ``data.dtype``.
+ shape: tuple
+ A length-2 tuple representing the matrix shape.
+ transpose: bool
+ A boolean specifying whether to transpose the sparse matrix
+ before computing.
+ If ``transpose=True``, the operator will compute based on the
+ event-driven property of the ``events`` vector.
+
+ Returns
+ -------
+ y : Array
+ The array of shape ``(shape[1] if transpose else shape[0],)`` representing
+ the matrix vector product.
+ """
+ data = as_jax(data)
+ indices = as_jax(indices)
+ indptr = as_jax(indptr)
+ events = as_jax(events)
+
+ # checking
+ data = jnp.atleast_1d(data)
+ if np.ndim(data) == 1:
+ if data.shape[0] not in [1, indices.shape[0]]:
+      raise ValueError('The size of data should be 1 or be consistent with indices. '
+ f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.')
+ else:
+ raise ValueError('data should be a scalar or 1D vector. '
+ f'But we got {np.ndim(data)}-D array.')
+ if np.ndim(indices) != 1:
+ raise ValueError('indices should be a 1D vector with integer type.')
+ if np.ndim(indptr) != 1:
+ raise ValueError('indptr should be a 1D vector with integer type.')
+ if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]:
+ raise ValueError(
+ 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.')
+ if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]:
+ raise ValueError(
+ 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.')
+ if np.ndim(events) != 1:
+ raise ValueError('events should be a 1D vector.')
+ if len(shape) != 2:
+ raise ValueError('shape should be a length-2 tuple.')
+ if transpose:
+ if events.shape[0] != shape[0]:
+ raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.')
+ else:
+ if events.shape[0] != shape[1]:
+ raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).')
+
+ # if the shape of indices is (0,), then we return a zero vector
+ if indices.shape[0] == 0:
+ return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)
+
+ return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]
+
+
+# -------------
+# CPU operators
+# -------------
+
+# 1. Benchmarking shows that the performance of the following transpose
+#    kernels is maximized when using serialized mode.
+# 2. Since our Taichi-JAX kernels do not support non-differentiable/non-jittable
+#    arguments, we have to define a separate kernel for each combination of
+#    such arguments.
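+# 3. Kernel variants below are named by three switches: ``transpose`` (whether
+#    the CSR matrix is transposed), ``bool`` vs. float events, and ``homo`` vs.
+#    ``heter`` values (a single shared weight vs. one weight per non-zero).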
+
+
+@ti.kernel
+def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ ti.loop_config(serialize=True)
+ for row_i in range(indptr.shape[0] - 1):
+ if events[row_i]:
+ for j in range(indptr[row_i], indptr[row_i + 1]):
+ out[indices[j]] += value
+
+
+@ti.kernel
+def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ ti.loop_config(serialize=True)
+ for row_i in range(indptr.shape[0] - 1):
+ if events[row_i]:
+ for j in range(indptr[row_i], indptr[row_i + 1]):
+ out[indices[j]] += values[j]
+
+
+@ti.kernel
+def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ ti.loop_config(serialize=True)
+ for row_i in range(indptr.shape[0] - 1):
+ if events[row_i] != 0.:
+ for j in range(indptr[row_i], indptr[row_i + 1]):
+ out[indices[j]] += value
+
+
+@ti.kernel
+def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ ti.loop_config(serialize=True)
+ for row_i in range(indptr.shape[0] - 1):
+ if events[row_i] != 0.:
+ for j in range(indptr[row_i], indptr[row_i + 1]):
+ out[indices[j]] += values[j]
+
+
+@ti.kernel
+def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ # ti.loop_config(serialize=True)
+ for row_i in range(indptr.shape[0] - 1):
+ r = 0.
+ for j in range(indptr[row_i], indptr[row_i + 1]):
+ if events[indices[j]]:
+ r += value
+ out[row_i] = r
+
+
+@ti.kernel
+def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ # ti.loop_config(serialize=True)
+ for row_i in range(indptr.shape[0] - 1):
+ r = 0.
+ for j in range(indptr[row_i], indptr[row_i + 1]):
+ if events[indices[j]]:
+ r += values[j]
+ out[row_i] = r
+
+
+@ti.kernel
+def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ # ti.loop_config(serialize=True)
+ for row_i in range(indptr.shape[0] - 1):
+ r = 0.
+ for j in range(indptr[row_i], indptr[row_i + 1]):
+ if events[indices[j]] != 0.:
+ r += value
+ out[row_i] = r
+
+
+@ti.kernel
+def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ # ti.loop_config(serialize=True)
+ for row_i in range(indptr.shape[0] - 1):
+ r = 0.
+ for j in range(indptr[row_i], indptr[row_i + 1]):
+ if events[indices[j]] != 0.:
+ r += values[j]
+ out[row_i] = r
+
+
+# -------------
+# GPU operators
+# -------------
+
+# 1. GPU kernels are different from the CPU ones, since the GPU kernels need
+# to use warp-level parallelism to achieve the best performance.
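+# 2. Each row is processed by a group of 32 consecutive loop indices
+#    (``row_i = i >> 5``, ``index = i & 31``); the 32 lanes then stride over the
+#    row's non-zeros with a step of 32, mimicking warp-level parallelism.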
+
+
+@ti.kernel
+def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ for i in range((indptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ if events[row_i]:
+ j = indptr[row_i] + index
+ end_index = indptr[row_i + 1]
+ while j < end_index:
+ out[indices[j]] += value
+ j += 32
+
+
+@ti.kernel
+def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ for i in range((indptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ if events[row_i] != 0.:
+ j = indptr[row_i] + index
+ end_index = indptr[row_i + 1]
+ while j < end_index:
+ out[indices[j]] += value
+ j += 32
+
+
+# TODO
+# It is important to note that the following warp-based kernels
+# should be improved, since the atomic_add for each thread is not
+# very efficient. Instead, the warp-level reduction primitive
+# should be used.
+# see ``warp_reduce_sum()`` function in tifunc.py.
+# However, currently Taichi does not support general warp-level primitives.
+
+
+@ti.kernel
+def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ for i in range((indptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ r = 0.
+ j = indptr[row_i] + index
+ end_index = indptr[row_i + 1]
+ while j < end_index:
+ if events[indices[j]]:
+ r += value
+ j += 32
+ out[row_i] += r # TODO: warp-level primitive
+
+
+@ti.kernel
+def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ for i in range((indptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ r = 0.
+ j = indptr[row_i] + index
+ end_index = indptr[row_i + 1]
+ while j < end_index:
+ if events[indices[j]] != 0.:
+ r += value
+ j += 32
+ out[row_i] += r # TODO: warp-level primitive
+
+
+@ti.kernel
+def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ for i in range((indptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ if events[row_i]:
+ j = indptr[row_i] + index
+ end_index = indptr[row_i + 1]
+ while j < end_index:
+ out[indices[j]] += values[j]
+ j += 32
+
+
+@ti.kernel
+def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ for i in range((indptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ if events[row_i] != 0.:
+ j = indptr[row_i] + index
+ end_index = indptr[row_i + 1]
+ while j < end_index:
+ out[indices[j]] += values[j]
+ j += 32
+
+
+@ti.kernel
+def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ for i in range((indptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ r = 0.
+ j = indptr[row_i] + index
+ end_index = indptr[row_i + 1]
+ while j < end_index:
+ if events[indices[j]]:
+ r += values[j]
+ j += 32
+ out[row_i] += r # TODO: warp-level primitive
+
+
+@ti.kernel
+def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
+ indices: ti.types.ndarray(ndim=1),
+ indptr: ti.types.ndarray(ndim=1),
+ events: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ for i in range((indptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ r = 0.
+ j = indptr[row_i] + index
+ end_index = indptr[row_i + 1]
+ while j < end_index:
+ if events[indices[j]] != 0.:
+ r += values[j]
+ j += 32
+ out[row_i] += r # TODO: warp-level primitive
+
+
+def raw_csrmv_taichi(
+ data: Union[float, jax.Array],
+ indices: jax.Array,
+ indptr: jax.Array,
+ events: jax.Array,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False
+):
+ if transpose:
+ if events.dtype == jnp.bool_:
+ if data.shape[0] == 1:
+ prim = _event_csrmv_transpose_bool_homo_p
+ else:
+ prim = _event_csrmv_transpose_bool_heter_p
+ else:
+ if data.shape[0] == 1:
+ prim = _event_csrmv_transpose_homo_p
+ else:
+ prim = _event_csrmv_transpose_heter_p
+ else:
+ if events.dtype == jnp.bool_:
+ if data.shape[0] == 1:
+ prim = _event_csrmv_bool_homo_p
+ else:
+ prim = _event_csrmv_bool_heter_p
+ else:
+ if data.shape[0] == 1:
+ prim = _event_csrmv_homo_p
+ else:
+ prim = _event_csrmv_heter_p
+
+ # computing
+ return prim(data,
+ indices,
+ indptr,
+ events,
+ outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)],
+ transpose=transpose,
+ shape=shape)
+
+
+def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
+ return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose)
+
+
+def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape):
+ return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose)
+
+
+def _event_csr_matvec_transpose_taichi(
+ ct, values, indices, indptr, events, *, outs, transpose, shape
+):
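+  # Transpose (VJP) rule: if ``events`` is the undefined primal, the cotangent
+  # is pushed back through a plain (non-event) CSR matvec; otherwise it is
+  # accumulated into ``values``: an inner product for a scalar (homogeneous)
+  # weight, or a per-non-zero gather through the COO row/col indices for
+  # heterogeneous weights.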
+ if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+ raise ValueError("Cannot transpose with respect to sparse indices.")
+ if ad.is_undefined_primal(events):
+ ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0]
+ return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events)
+ else:
+ if type(ct[0]) is ad.Zero:
+ ct_values = ad.Zero(values)
+ else:
+ if values.aval.shape[0] == 1: # scalar
+ ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
+ ct_values = jnp.inner(ct[0], ct_values)
+ else: # heterogeneous values
+ row, col = csr_to_coo(indices, indptr)
+ ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
+ return ct_values, indices, indptr, events
+
+
+def _define_op(cpu_kernel, gpu_kernel):
+ prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
+ prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi)
+ prim.def_transpose_rule(_event_csr_matvec_transpose_taichi)
+ return prim
+
+
+# transpose bool homo
+_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu,
+ _event_csr_matvec_transpose_bool_homo_gpu)
+
+# transpose homo
+_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu)
+
+# not transpose bool homo
+_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu)
+
+# not transpose homo
+_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu)
+
+# transpose bool heter
+_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu,
+ _event_csr_matvec_transpose_bool_heter_gpu)
+
+# transpose heter
+_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu,
+ _event_csr_matvec_transpose_heter_gpu)
+
+# not transpose bool heter
+_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu)
+
+# not transpose heter
+_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu)
diff --git a/brainpy/_src/math/event/_info_collection.py b/brainpy/_src/math/event/_info_collection.py
index 9f8a5f31a..7bb043e3e 100644
--- a/brainpy/_src/math/event/_info_collection.py
+++ b/brainpy/_src/math/event/_info_collection.py
@@ -6,15 +6,16 @@
import numba
from jax import dtypes, numpy as jnp
from jax.core import ShapedArray
-from jax.interpreters import batching
from jax.lib import xla_client
+from brainpy._src.dependency_check import import_brainpylib_gpu_ops
+from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import register_op_with_numba
from brainpy._src.math.ndarray import Array
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops
+from brainpy._src.math.op_register.base import XLACustomOp
from brainpy.errors import GPUOperatorNotFound
+ti = import_taichi()
__all__ = [
'info'
@@ -40,7 +41,7 @@ def info(events: Union[Array, jax.Array]) -> Tuple[jax.Array, jax.Array]:
events = as_jax(events)
if events.ndim != 1:
raise TypeError('Only support 1D boolean vector.')
- return event_info_p.bind(events)
+ return event_info_p(events)
def _batch_event_info_abstract(events):
@@ -66,11 +67,26 @@ def _batch_event_info(outs, ins):
event_num[batch_idx] = num
+@ti.kernel
+def _batch_event_info_taichi(events: ti.types.ndarray(ndim=2),
+ event_ids: ti.types.ndarray(ndim=2),
+ event_num: ti.types.ndarray(ndim=1)):
+  for i, j in ti.ndrange(event_ids.shape[0], event_ids.shape[1]):
+ event_ids[i, j] = -1
+ for batch_idx in range(event_ids.shape[0]):
+ num = 0
+ for i in range(event_ids.shape[1]):
+ if events[batch_idx, i]:
+ event_ids[batch_idx, num] = i
+ num += 1
+ event_num[batch_idx] = num
+
+
def _batch_event_info_batching_rule(args, axes):
arg = jnp.moveaxis(args[0], axes[0], 0)
shape = arg.shape
arg = jnp.reshape(arg, (shape[0] * shape[1], shape[2]))
- event_ids, event_num = batch_event_info_p.bind(arg)
+ event_ids, event_num = batch_event_info_p(arg)
return ((jnp.reshape(event_ids, shape), jnp.reshape(event_num, shape[:2])),
(0, 0))
@@ -121,17 +137,16 @@ def _event_info_gpu_translation(c, events):
)
-batch_event_info_p = register_op_with_numba(
- op_name='event_info',
- cpu_func=_batch_event_info,
- out_shapes=_batch_event_info_abstract,
- gpu_func_translation=_event_info_gpu_translation,
- multiple_results=True
+batch_event_info_p = XLACustomOp(
+ name='batched_event_info',
+ cpu_kernel=_batch_event_info_taichi,
+ gpu_kernel=_batch_event_info_taichi,
+ outs=_batch_event_info_abstract,
)
-batching.primitive_batchers[batch_event_info_p] = _batch_event_info_batching_rule
+batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule)
-def _event_info_abstract(events):
+def _event_info_abstract(events, **kwargs):
assert events.ndim == 1
# assert events.dtype == jnp.bool_
event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape)
@@ -154,16 +169,30 @@ def _event_info(outs, ins):
event_num[0] = num
+@ti.kernel
+def _event_info_taichi(events: ti.types.ndarray(ndim=1),
+ event_ids: ti.types.ndarray(ndim=1),
+ event_num: ti.types.ndarray(ndim=1)):
+ for i in range(event_ids.shape[0]):
+ event_ids[i] = -1
+ num = 0
+  ti.loop_config(serialize=True)  # the running count ``num`` must be accumulated sequentially
+  for i in range(event_ids.shape[0]):
+ if events[i]:
+ event_ids[num] = i
+ num += 1
+ event_num[0] = num
+
+
def _event_info_batching_rule(args, axes):
arg = jnp.moveaxis(args[0], axes[0], 0)
- return (batch_event_info_p.bind(arg), (0, 0))
+ return (batch_event_info_p(arg), (0, 0))
-event_info_p = register_op_with_numba(
- op_name='event_info',
- cpu_func=_event_info,
- out_shapes=_event_info_abstract,
- gpu_func_translation=_event_info_gpu_translation,
- multiple_results=True
+event_info_p = XLACustomOp(
+ name='event_info',
+ cpu_kernel=_event_info_taichi,
+ gpu_kernel=_event_info_taichi,
+ outs=_event_info_abstract,
+ # gpu_func_translation=_event_info_gpu_translation,
)
-batching.primitive_batchers[event_info_p] = _event_info_batching_rule
+event_info_p.def_batching_rule(_event_info_batching_rule)
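+
+# Illustrative usage of the public wrapper defined above (a sketch; it assumes
+# ``info`` is re-exported as ``brainpy.math.event.info``):
+#
+#   import brainpy.math as bm
+#   events = bm.random.rand(100) < 0.2          # 1D boolean spike vector
+#   event_ids, event_num = bm.event.info(events)
+#   # ``event_ids`` holds the indices of the True entries (padded with -1);
+#   # ``event_num`` is a length-1 array holding how many entries are True.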
diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py
new file mode 100644
index 000000000..3ac1e0ee2
--- /dev/null
+++ b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py
@@ -0,0 +1,254 @@
+# from jax_taichi import jax_taichi_call
+
+import time
+from functools import partial
+import os
+
+import brainpy as bp
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import taichi as ti
+
+bm.set_platform('cpu')
+
+s = [1000, 5000, 10000, 20000, 25000, 30000]
+p = [0.1, 0.2, 0.3, 0.4, 0.5]
+
+shape = [
+ 1000,
+ 2500,
+ 5000,
+ 10000,
+ 25000,
+ 37500,
+ 50000
+]
+
+
+
+values_type = [
+ 'homo',
+ 'heter'
+ ]
+events_type = [
+ 'bool',
+ 'float',
+ ]
+transpose = [
+ True,
+ False
+ ]
+
+ITERATION = 100
+if bm.get_platform() == 'cpu':
+ ITERATION = 10
+
+print(bm.get_platform())
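+
+# Benchmark protocol: each jitted operator below is first called 5 times for
+# warm-up/compilation, then timed over 10 separate runs of ``ITERATION`` fused
+# calls each; the per-run times (in ms) are written to a CSV next to this file.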
+
+@partial(jax.jit, static_argnums=(4, 5))
+def event_csrmv_taichi(weight, indices, indptr, vector, shape, transpose):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0]
+ return r
+
+@partial(jax.jit, static_argnums=(4, 5))
+def event_csrmv(weight, indices, indptr, vector, shape, transpose):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)
+ return r
+
+def test_event_csrmv(shape, values_type, events_type, transpose):
+ rng = bm.random.RandomState(seed=1234)
+ indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
+ vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ weight = 1.
+
+
+ if events_type == 'float':
+ vector = vector.astype(bm.float32)
+ if values_type == 'heter':
+ heter_data = bm.ones(indices.shape) * weight
+ weight = heter_data
+
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+
+ time0 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time19 = time.time()
+
+
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+
+ time20 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+ # assert(jnp.allclose(result1[0], result2))
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+# init dataframe
+df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose',
+ 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
+ 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
+ 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+ 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
+
+### RECTANGULAR MATRIX
+if (bm.get_platform() == 'cpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _values_type in values_type:
+ for _events_type in events_type:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
+ # append to dataframe
+ df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/event_csrmv_cpu.csv', index=False)
+
+if (bm.get_platform() == 'gpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _values_type in values_type:
+ for _events_type in events_type:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
+ # append to dataframe
+ df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/event_csrmv_gpu.csv', index=False)
diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py
new file mode 100644
index 000000000..98793e600
--- /dev/null
+++ b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py
@@ -0,0 +1,271 @@
+# from jax_taichi import jax_taichi_call
+
+import time
+from functools import partial
+import os
+
+import brainpy as bp
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import taichi as ti
+
+bm.set_platform('cpu')
+
+s = [1000, 5000, 10000, 20000, 25000, 30000]
+p = [0.1, 0.2, 0.3, 0.4, 0.5]
+
+shape = [
+ 1000,
+ 2500,
+ 5000,
+ 10000,
+ 25000,
+ 37500,
+ 50000
+]
+
+
+
+values_type = [
+ 'homo',
+ 'heter'
+ ]
+events_type = [
+ 'bool',
+ 'float',
+ ]
+transpose = [
+ True,
+ False
+ ]
+
+ITERATION = 100
+if bm.get_platform() == 'cpu':
+ ITERATION = 10
+
+print(bm.get_platform())
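+
+# Same protocol as ``event_csrmv_taichi_VS_event_csrmv.py`` (5 warm-up calls,
+# then 10 timed runs), but here each run times ``jax.grad`` of the summed
+# output with respect to the event vector (``argnums=3``).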
+
+def sum_op(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)
+ return r.sum()
+
+ return func
+
+
+def sum_op2(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)[0]
+ return r.sum()
+
+ return func
+
+@partial(jax.jit, static_argnums=(4, 5))
+def event_csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)(
+ weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+ return r
+
+@partial(jax.jit, static_argnums=(4, 5))
+def event_csrmv_grad(weight, indices, indptr, vector, shape, transpose):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.event.csrmv), argnums=3)(
+ weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+ return r
+
+
+def test_event_csrmv(shape, values_type, events_type, transpose):
+ rng = bm.random.RandomState(seed=1234)
+ indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
+ vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ weight = 1.
+
+
+ if events_type == 'float':
+ vector = vector.astype(bm.float32)
+ if values_type == 'heter':
+ heter_data = bm.ones(indices.shape) * weight
+ weight = heter_data
+
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+# init dataframe
+df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose',
+ 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
+ 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
+ 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+ 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
+
+### RECTANGULAR MATRIX
+if (bm.get_platform() == 'cpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _values_type in values_type:
+ for _events_type in events_type:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
+ # append to dataframe
+ df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/event_csrmv_grad_cpu.csv', index=False)
+
+if (bm.get_platform() == 'gpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _values_type in values_type:
+ for _events_type in events_type:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
+ # append to dataframe
+ df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/event_csrmv_grad_gpu.csv', index=False)
diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py
index a2374d487..e0f38490f 100644
--- a/brainpy/_src/math/event/tests/test_event_csrmv.py
+++ b/brainpy/_src/math/event/tests/test_event_csrmv.py
@@ -8,13 +8,8 @@
import brainpy as bp
import brainpy.math as bm
-import platform
-import pytest
-
-is_manual_test = False
-if platform.system() == 'Windows' and not is_manual_test:
- pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
+seed = 1234
def sum_op(op):
@@ -24,127 +19,92 @@ def func(*args, **kwargs):
return func
+taichi_csr_matvec = bm.event.csrmv
-class Test_event_csr_matvec(parameterized.TestCase):
+class Test_event_csr_matvec_taichi(parameterized.TestCase):
def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_event_csr_matvec, self).__init__(*args, **kwargs)
- bm.set_platform(platform)
+ super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs)
+
print()
+ bm.set_platform(platform)
- @parameterized.named_parameters(
- dict(
- testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}',
- transpose=transpose,
- shape=shape,
- homo_data=homo_data,
- )
- for transpose in [True, False]
- for shape in [(100, 200),
- (200, 200),
- (200, 100),
- (10, 1000),
- (2, 10000),
- (1000, 10),
- (10000, 2)]
- for homo_data in [-1., 0., 1.]
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000)],
+ homo_data=[-1., 0., 1.],
)
- def test_homo(self, shape, transpose, homo_data):
+ def test_homo(self, transpose, shape, homo_data):
print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
-
- rng = bm.random.RandomState()
+ rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
events = rng.random(shape[0] if transpose else shape[1]) < 0.1
heter_data = bm.ones(indices.shape) * homo_data
- r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
- r2 = bm.event.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r2))
-
- r3 = bm.event.csrmv(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r3))
-
dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- r4 = (events @ dense) if transpose else (dense @ events)
- self.assertTrue(bm.allclose(r1, r4))
+ r1 = (events @ dense) if transpose else (dense @ events)
+ r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
- r5 = bm.event.csrmv(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r5))
+ assert (bm.allclose(r1, r2))
bm.clear_buffer_memory()
- @parameterized.named_parameters(
- dict(
- testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}',
- transpose=transpose,
- shape=shape,
- homo_data=homo_data,
- )
- for transpose in [True, False]
- for shape in [(100, 200),
- (200, 200),
- (200, 100),
- (10, 1000),
- (2, 10000),
- (1000, 10),
- (100000, 2)]
- for homo_data in [-1., 0., 1.]
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000)],
+ homo_data=[-1., 0., 1.],
)
- def test_homo_vamp(self, shape, transpose, homo_data):
+ def test_homo_vmap(self, shape, transpose, homo_data):
print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
- rng = bm.random.RandomState()
+ rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
# vmap 'data'
events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events,
+ f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events,
+ shape=shape, transpose=transpose))
+ f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events,
shape=shape, transpose=transpose))
- f2 = jax.vmap(
- partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float),
- shape=shape, transpose=transpose))
vmap_data = bm.as_jax([homo_data] * 10)
self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
# vmap 'events'
- f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr,
+ f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr,
shape=shape, transpose=transpose))
- f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr,
+ f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr,
shape=shape, transpose=transpose))
vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
- self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float))))
+ self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))
# vmap 'data' and 'events'
- f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
- f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose,
- method='cusparse'))
+ f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
+ f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose))
+
vmap_data1 = bm.as_jax([homo_data] * 10)
vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
- f6(vmap_data1, vmap_data2.astype(float))))
+ f6(vmap_data1, vmap_data2)))
bm.clear_buffer_memory()
- @parameterized.named_parameters(
- dict(
- testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}',
- homo_data=homo_data,
- shape=shape,
- transpose=transpose,
- )
- for transpose in [True, False]
- for shape in [(100, 200),
- (200, 200),
- (200, 100),
- (10, 1000),
- (2, 10000),
- (1000, 10),
- (100000, 2)]
- for homo_data in [-1., 0., 1.]
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000)],
+ homo_data=[-1., 0., 1.],
)
def test_homo_grad(self, shape, transpose, homo_data):
print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
- rng = bm.random.RandomState()
+ rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
@@ -152,140 +112,102 @@ def test_homo_grad(self, shape, transpose, homo_data):
dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)
# grad 'data'
- r1 = jax.grad(sum_op(bm.event.csrmv))(
+ r1 = jax.grad(sum_op(bm.sparse.csrmv))(
+ homo_data, indices, indptr, events, shape=shape, transpose=transpose)
+ r2 = jax.grad(sum_op(taichi_csr_matvec))(
homo_data, indices, indptr, events, shape=shape, transpose=transpose)
- r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))(
- homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r1, r2))
- r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else
- ((dense_conn * a) @ events))))(homo_data)
- self.assertTrue(bm.allclose(r1, r3))
# grad 'events'
- r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
+ r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)(
+ r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else
- ((dense_conn * homo_data) @ e))))(events.astype(float))
- self.assertTrue(bm.allclose(r4, r5))
- self.assertTrue(bm.allclose(r4, r6))
+ self.assertTrue(bm.allclose(r3, r4))
bm.clear_buffer_memory()
- @parameterized.named_parameters(
- dict(
- testcase_name=f'transpose={transpose}, shape={shape}',
- shape=shape,
- transpose=transpose,
- )
- for transpose in [True, False]
- for shape in [(100, 200),
- (200, 200),
- (200, 100),
- (10, 1000),
- (2, 10000),
- (1000, 10),
- (10000, 2)]
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000), ]
)
def test_heter(self, shape, transpose):
print(f'test_heter: shape = {shape}, transpose = {transpose}')
-
- rng = bm.random.RandomState()
+ rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
heter_data = bm.as_jax(rng.random(indices.shape))
- r1 = bm.event.csrmv(heter_data, indices, indptr, events,
+ r1 = bm.sparse.csrmv(heter_data, indices, indptr, events,
shape=shape, transpose=transpose)
- r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float),
- shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r2))
-
- dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- r3 = (events @ dense) if transpose else (dense @ events)
- self.assertTrue(bm.allclose(r1, r3))
+ r2 = taichi_csr_matvec(heter_data, indices, indptr, events,
+ shape=shape, transpose=transpose)
- r4 = bm.event.csrmv(heter_data, indices, indptr, events.astype(float),
- shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r4))
+ assert (bm.allclose(r1, r2))
bm.clear_buffer_memory()
- @parameterized.named_parameters(
- dict(
- testcase_name=f"transpose={transpose}, shape={shape}",
- shape=shape,
- transpose=transpose,
- )
- for transpose in [True, False]
- for shape in [(100, 200),
- (200, 200),
- (200, 100),
- (10, 1000),
- (2, 10000),
- (1000, 10),
- (100000, 2)]
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000)]
)
- def test_heter_vamp(self, shape, transpose):
+ def test_heter_vmap(self, shape, transpose):
print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}')
- rng = bm.random.RandomState()
+ rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
# vmap 'data'
events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events,
+ f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events,
+ shape=shape, transpose=transpose))
+ f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events,
shape=shape, transpose=transpose))
- f2 = jax.vmap(
- partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float),
- shape=shape, transpose=transpose))
vmap_data = bm.as_jax(rng.random((10, indices.shape[0])))
self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
# vmap 'events'
data = bm.as_jax(rng.random(indices.shape))
- f3 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr,
+ f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr,
shape=shape, transpose=transpose))
- f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr,
+ f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr,
shape=shape, transpose=transpose))
vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
- self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float))))
+ self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))
# vmap 'data' and 'events'
- f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee,
+ f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee,
shape=shape, transpose=transpose))
- f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee,
- shape=shape, transpose=transpose))
+ f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee,
+ shape=shape, transpose=transpose))
vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0])))
vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
- f6(vmap_data1, vmap_data2.astype(float))))
+ f6(vmap_data1, vmap_data2)))
bm.clear_buffer_memory()
- @parameterized.named_parameters(
- dict(testcase_name=f'transpose={transpose},shape={shape}',
- shape=shape,
- transpose=transpose,
- )
- for transpose in [True, False]
- for shape in [(100, 200),
- (200, 200),
- (200, 100),
- (10, 1000),
- (2, 10000),
- (1000, 10),
- (100000, 2)]
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000)]
)
def test_heter_grad(self, shape, transpose):
print(f'test_heter_grad: shape = {shape}, transpose = {transpose}')
- rng = bm.random.RandomState()
+ rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
@@ -295,27 +217,24 @@ def test_heter_grad(self, shape, transpose):
# grad 'data'
data = bm.as_jax(rng.random(indices.shape))
- r1 = jax.grad(sum_op(bm.event.csrmv))(
+ r1 = jax.grad(sum_op(bm.sparse.csrmv))(
+ data, indices, indptr, events, shape=shape, transpose=transpose)
+ r2 = jax.grad(sum_op(taichi_csr_matvec))(
data, indices, indptr, events, shape=shape, transpose=transpose)
- r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))(
- data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r1, r2))
- dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape)
- r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else
- (a @ events))))(dense_data)
- rows, cols = bm.sparse.csr_to_coo(indices, indptr)
- r3 = r3[rows, cols]
- self.assertTrue(bm.allclose(r1, r3))
-
# grad 'events'
- r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
+ r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
+ data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
+ data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r3, r4))
+
+ r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))(
data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)(
+ r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))(
data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else
- (dense_data @ e))))(events.astype(float))
- self.assertTrue(bm.allclose(r4, r5))
- self.assertTrue(bm.allclose(r4, r6))
+ self.assertTrue(bm.allclose(r5[0], r6[0]))
+ self.assertTrue(bm.allclose(r5[1], r6[1]))
bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py b/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py
deleted file mode 100644
index a5b8df152..000000000
--- a/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-import jax
-import pytest
-
-import test_event_csrmv
-
-if jax.default_backend() != 'gpu':
- pytest.skip("No gpu available.", allow_module_level=True)
-
-
-class Test_event_csr_matvec_GPU(test_event_csrmv.Test_event_csr_matvec):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs, platform='gpu')
diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_old.py b/brainpy/_src/math/event/tests/test_event_csrmv_old.py
new file mode 100644
index 000000000..31a6527a2
--- /dev/null
+++ b/brainpy/_src/math/event/tests/test_event_csrmv_old.py
@@ -0,0 +1,324 @@
+# -*- coding: utf-8 -*-
+
+
+from functools import partial
+
+import jax
+from absl.testing import parameterized
+
+import brainpy as bp
+import brainpy.math as bm
+import platform
+
+import pytest
+pytest.skip('Old implementation.', allow_module_level=True)
+
+is_manual_test = False
+# if platform.system() == 'Windows' and not is_manual_test:
+# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
+
+brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib')
+taichi_csr_matvec = partial(bm.event.csrmv, method='taichi')
+
+def sum_op(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)
+ return r.sum()
+
+ return func
+
+
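`sum_op` above reduces a vector-valued operator to a scalar so that `jax.grad` can be applied to it; a small standalone sketch of the same idea:

```python
import jax
import jax.numpy as jnp

def sum_op(op):
    def func(*args, **kwargs):
        return op(*args, **kwargs).sum()
    return func

# Gradient of sum(sin(x)) with respect to x, i.e. cos(x) element-wise.
grad_fn = jax.grad(sum_op(jnp.sin))
print(grad_fn(jnp.array([0.0, 1.0, 2.0])))
```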
+class Test_event_csr_matvec(parameterized.TestCase):
+ def __init__(self, *args, platform='cpu', **kwargs):
+ super(Test_event_csr_matvec, self).__init__(*args, **kwargs)
+ bm.set_platform(platform)
+ print()
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}',
+ transpose=transpose,
+ shape=shape,
+ homo_data=homo_data,
+ )
+ for transpose in [True, False]
+ for shape in [(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000),
+ (2, 10000),
+ (1000, 10),
+ (10000, 2)]
+ for homo_data in [-1., 0., 1.]
+ )
+ def test_homo(self, shape, transpose, homo_data):
+ print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
+
+ rng = bm.random.RandomState()
+ indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ heter_data = bm.ones(indices.shape) * homo_data
+
+ r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
+ r2 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r2))
+
+ r3 = brainpylib_csr_matvec(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r3))
+
+ dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+ r4 = (events @ dense) if transpose else (dense @ events)
+ self.assertTrue(bm.allclose(r1, r4))
+
+ r5 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r5))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}',
+ transpose=transpose,
+ shape=shape,
+ homo_data=homo_data,
+ )
+ for transpose in [True, False]
+ for shape in [(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000),
+ (2, 10000),
+ (1000, 10),
+ (100000, 2)]
+ for homo_data in [-1., 0., 1.]
+ )
+ def test_homo_vmap(self, shape, transpose, homo_data):
+    print(f'test_homo_vmap: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
+
+ rng = bm.random.RandomState()
+ indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
+
+ # vmap 'data'
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events,
+ shape=shape, transpose=transpose))
+ f2 = jax.vmap(
+ partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float),
+ shape=shape, transpose=transpose))
+ vmap_data = bm.as_jax([homo_data] * 10)
+ self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
+
+ # vmap 'events'
+ f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr,
+ shape=shape, transpose=transpose))
+ f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr,
+ shape=shape, transpose=transpose))
+ vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
+ self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float))))
+
+ # vmap 'data' and 'events'
+ f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose))
+ f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose,
+ method='cusparse'))
+ vmap_data1 = bm.as_jax([homo_data] * 10)
+ vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
+ self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
+ f6(vmap_data1, vmap_data2.astype(float))))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}',
+ homo_data=homo_data,
+ shape=shape,
+ transpose=transpose,
+ )
+ for transpose in [True, False]
+ for shape in [(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000),
+ (2, 10000),
+ (1000, 10),
+ (100000, 2)]
+ for homo_data in [-1., 0., 1.]
+ )
+ def test_homo_grad(self, shape, transpose, homo_data):
+ print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
+
+ rng = bm.random.RandomState()
+ indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)
+
+ # grad 'data'
+ r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
+ homo_data, indices, indptr, events, shape=shape, transpose=transpose)
+ r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))(
+ homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r2))
+ r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else
+ ((dense_conn * a) @ events))))(homo_data)
+ self.assertTrue(bm.allclose(r1, r3))
+
+ # grad 'events'
+ r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
+ homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)(
+ homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else
+ ((dense_conn * homo_data) @ e))))(events.astype(float))
+ self.assertTrue(bm.allclose(r4, r5))
+ self.assertTrue(bm.allclose(r4, r6))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name=f'transpose={transpose}, shape={shape}',
+ shape=shape,
+ transpose=transpose,
+ )
+ for transpose in [True, False]
+ for shape in [(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000),
+ (2, 10000),
+ (1000, 10),
+ (10000, 2)]
+ )
+ def test_heter(self, shape, transpose):
+ print(f'test_heter: shape = {shape}, transpose = {transpose}')
+
+ rng = bm.random.RandomState()
+ indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ heter_data = bm.as_jax(rng.random(indices.shape))
+
+ r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events,
+ shape=shape, transpose=transpose)
+ r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float),
+ shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r2))
+
+ dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+ r3 = (events @ dense) if transpose else (dense @ events)
+ self.assertTrue(bm.allclose(r1, r3))
+
+ r4 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float),
+ shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r4))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name=f"transpose={transpose}, shape={shape}",
+ shape=shape,
+ transpose=transpose,
+ )
+ for transpose in [True, False]
+ for shape in [(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000),
+ (2, 10000),
+ (1000, 10),
+ (100000, 2)]
+ )
+ def test_heter_vmap(self, shape, transpose):
+    print(f'test_heter_vmap: shape = {shape}, transpose = {transpose}')
+
+ rng = bm.random.RandomState()
+ indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+
+ # vmap 'data'
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events,
+ shape=shape, transpose=transpose))
+ f2 = jax.vmap(
+ partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float),
+ shape=shape, transpose=transpose))
+ vmap_data = bm.as_jax(rng.random((10, indices.shape[0])))
+ self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
+
+ # vmap 'events'
+ data = bm.as_jax(rng.random(indices.shape))
+ f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr,
+ shape=shape, transpose=transpose))
+ f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr,
+ shape=shape, transpose=transpose))
+ vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
+ self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float))))
+
+ # vmap 'data' and 'events'
+ f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee,
+ shape=shape, transpose=transpose))
+ f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee,
+ shape=shape, transpose=transpose))
+ vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0])))
+ vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
+ self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
+ f6(vmap_data1, vmap_data2.astype(float))))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'transpose={transpose},shape={shape}',
+ shape=shape,
+ transpose=transpose,
+ )
+ for transpose in [True, False]
+ for shape in [(100, 200),
+ (200, 200),
+ (200, 100),
+ (10, 1000),
+ (2, 10000),
+ (1000, 10),
+ (100000, 2)]
+ )
+ def test_heter_grad(self, shape, transpose):
+ print(f'test_heter_grad: shape = {shape}, transpose = {transpose}')
+
+ rng = bm.random.RandomState()
+ indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)
+
+ # grad 'data'
+ data = bm.as_jax(rng.random(indices.shape))
+ r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
+ data, indices, indptr, events, shape=shape, transpose=transpose)
+ r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))(
+ data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r2))
+
+ dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape)
+ r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else
+ (a @ events))))(dense_data)
+ rows, cols = bm.sparse.csr_to_coo(indices, indptr)
+ r3 = r3[rows, cols]
+ self.assertTrue(bm.allclose(r1, r3))
+
+ # grad 'events'
+ r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
+ data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)(
+ data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
+ r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else
+ (dense_data @ e))))(events.astype(float))
+ self.assertTrue(bm.allclose(r4, r5))
+ self.assertTrue(bm.allclose(r4, r6))
+
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/interoperability.py b/brainpy/_src/math/interoperability.py
index 766d4f8e1..948538371 100644
--- a/brainpy/_src/math/interoperability.py
+++ b/brainpy/_src/math/interoperability.py
@@ -8,6 +8,9 @@
__all__ = [
'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable',
+ 'from_numpy',
+
+ 'is_bp_array'
]
@@ -15,6 +18,12 @@ def _as_jax_array_(obj):
return obj.value if isinstance(obj, Array) else obj
+def is_bp_array(x):
+ """Check if the input is a ``brainpy.math.Array``.
+ """
+ return isinstance(x, Array)
+
+
def as_device_array(tensor, dtype=None):
"""Convert the input to a ``jax.numpy.DeviceArray``.
@@ -93,3 +102,8 @@ def as_variable(tensor, dtype=None):
"""
from .object_transform.variables import Variable
return Variable(tensor, dtype=dtype)
+
+
+def from_numpy(arr, dtype=None):
+ return as_ndarray(arr, dtype=dtype)
+
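A minimal usage sketch of the two helpers added above, assuming they are re-exported under `brainpy.math` like the other interoperability functions and that `as_ndarray` returns BrainPy's `Array` type:

```python
import numpy as np
import brainpy.math as bm

x_np = np.arange(4, dtype=np.float32)
x_bp = bm.from_numpy(x_np)                           # wrap the NumPy data in a BrainPy array
print(bm.is_bp_array(x_bp), bm.is_bp_array(x_np))    # expected: True False
```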
diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py
index 718de03d8..a79cdc982 100644
--- a/brainpy/_src/math/jitconn/__init__.py
+++ b/brainpy/_src/math/jitconn/__init__.py
@@ -1,3 +1,3 @@
from ._matvec import *
-from ._event_matvec import *
+from ._event_matvec import *
\ No newline at end of file
diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py
index d739919f7..3671755a9 100644
--- a/brainpy/_src/math/jitconn/_event_matvec.py
+++ b/brainpy/_src/math/jitconn/_event_matvec.py
@@ -10,18 +10,29 @@
from jax.interpreters import xla, ad
from jax.lib import xla_client
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops
+from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p,
mv_prob_uniform_p,
mv_prob_normal_p,
mv_prob_homo,
mv_prob_uniform,
- mv_prob_normal)
+ mv_prob_normal,
+ _general_checking,
+ raw_mv_prob_homo,
+ raw_mv_prob_uniform,
+ raw_mv_prob_normal,
+ _mv_prob_homo_transpose,
+ _mv_prob_uniform_transpose,
+ _mv_prob_normal_transpose,
+ _reverse)
from brainpy._src.math.ndarray import _get_dtype
-from brainpy._src.math.op_register import register_general_batching
+from brainpy._src.math.op_register import register_general_batching, XLACustomOp
+from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal)
from brainpy.errors import GPUOperatorNotFound
+ti = import_taichi()
+
__all__ = [
'event_mv_prob_homo',
'event_mv_prob_uniform',
@@ -38,6 +49,58 @@ def event_mv_prob_homo(
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
+) -> jax.Array:
+ return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__
+
+
+def event_mv_prob_uniform(
+ events: jax.Array,
+ w_low: float,
+ w_high: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__
+
+
+def event_mv_prob_normal(
+ events: jax.Array,
+ w_mu: float,
+ w_sigma: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+  return event_mv_prob_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__
+
+
+### BRAINPYLIB ###
+
+def event_mv_prob_homo_brainpylib(
+ events: jax.Array,
+ weight: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
) -> jax.Array:
events = as_jax(events)
weight = jnp.atleast_1d(as_jax(weight))
@@ -57,10 +120,10 @@ def event_mv_prob_homo(
return r
-event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__
+event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__
-def event_mv_prob_uniform(
+def event_mv_prob_uniform_brainpylib(
events: jax.Array,
w_low: float,
w_high: float,
@@ -90,10 +153,10 @@ def event_mv_prob_uniform(
outdim_parallel=outdim_parallel)[0]
-event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__
+event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__
-def event_mv_prob_normal(
+def event_mv_prob_normal_brainpylib(
events: jax.Array,
w_mu: float,
w_sigma: float,
@@ -123,7 +186,7 @@ def event_mv_prob_normal(
outdim_parallel=outdim_parallel)[0]
-event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__
+event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__
def _event_matvec_prob_homo_abstract(
@@ -295,8 +358,8 @@ def _event_matvec_prob_homo_transpose(
event_mv_prob_homo_p.multiple_results = True
event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract)
event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p))
-xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation
-xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation
+# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation
+# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation
ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp
ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose
register_general_batching(event_mv_prob_homo_p)
@@ -466,8 +529,8 @@ def _event_matvec_prob_uniform_transpose(
event_mv_prob_uniform_p.multiple_results = True
event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract)
event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p))
-xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation
-xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation
+# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation
+# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation
register_general_batching(event_mv_prob_uniform_p)
ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp
ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose
@@ -660,8 +723,1266 @@ def _event_matvec_prob_normal_transpose(
event_mv_prob_normal_p.multiple_results = True
event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract)
event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p))
-xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation
-xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation
+# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation
+# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation
register_general_batching(event_mv_prob_normal_p)
ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp
ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose
+
+
+### TAICHI ###
+
+def event_mv_prob_homo_taichi(
+ events: jax.Array,
+ weight: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+ In this operation, :math:`M` is the random matrix with a connection probability
+ `conn_prob`, and at each connection the value is the same scalar `weight`.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` during the just-in-time
+    matrix generation, you should set ``outdim_parallel=True``, at the cost of
+    speed compared with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ events: Array, ndarray
+ The events.
+ weight: float
+ The value of the random matrix.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+    Whether to perform the random generation in parallel along the output dimension.
+    It can be used to ensure that the just-in-time generated :math:`M^T` is the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ events = as_jax(events)
+ if isinstance(weight, float): weight = as_jax(weight)
+ weight = jnp.atleast_1d(as_jax(weight))
+ conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+ conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+ if seed is None:
+ with jax.ensure_compile_time_eval():
+ seed = np.random.randint(0, int(1e8), 1)
+ seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+ return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)[0]
+
+
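As a reading aid, a minimal call of the wrapper defined above (sizes and probability purely illustrative) might look like:

```python
import brainpy.math as bm

events = bm.random.rand(1000) < 0.1                       # boolean spike vector of length shape[0]
out = bm.jitconn.event_mv_prob_homo(events, weight=1.5, conn_prob=0.05, seed=42,
                                    shape=(1000, 2000), transpose=True,
                                    outdim_parallel=True)
print(out.shape)                                          # (2000,) when transpose=True (y = M^T @ v)
```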
+def event_mv_prob_uniform_taichi(
+ events: jax.Array,
+ w_low: float,
+ w_high: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a uniform distribution for its value.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+ In this operation, :math:`M` is the random matrix with a connection probability
+  `conn_prob`, and at each connection the value is drawn uniformly from the interval between `w_low` and `w_high`.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` during the just-in-time
+    matrix generation, you should set ``outdim_parallel=True``, at the cost of
+    speed compared with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ events: Array, ndarray
+ The events.
+ w_low: float
+ Lower boundary of the output interval.
+ w_high: float
+ Upper boundary of the output interval.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+    Whether to perform the random generation in parallel along the output dimension.
+    It can be used to ensure that the just-in-time generated :math:`M^T` is the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ events = as_jax(events)
+ if isinstance(w_low, float): w_low = as_jax(w_low)
+ if isinstance(w_high, float): w_high = as_jax(w_high)
+ w_low = jnp.atleast_1d(as_jax(w_low))
+ w_high = jnp.atleast_1d(as_jax(w_high))
+ conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+ conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+ if seed is None:
+ with jax.ensure_compile_time_eval():
+ seed = np.random.randint(0, int(1e8), 1)
+ seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+ return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)[0]
+
+
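The uniform (and normal) variants follow the same calling convention, replacing the single `weight` with distribution parameters; an illustrative sketch:

```python
import brainpy.math as bm

events = bm.random.rand(500) < 0.2
out = bm.jitconn.event_mv_prob_uniform(events, w_low=0.0, w_high=1.0, conn_prob=0.1,
                                       seed=0, shape=(500, 800), transpose=True)
print(out.shape)   # (800,)
```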
+def event_mv_prob_normal_taichi(
+ events: jax.Array,
+ w_mu: float,
+ w_sigma: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a normal distribution for its value.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+ In this operation, :math:`M` is the random matrix with a connection probability
+  `conn_prob`, and at each connection the value is drawn from a normal distribution with mean `w_mu` and standard deviation `w_sigma`.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` during the just-in-time
+    matrix generation, you should set ``outdim_parallel=True``, at the cost of
+    speed compared with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ events: Array, ndarray
+ The events.
+ w_mu: float
+ Mean (centre) of the distribution.
+ w_sigma: float
+ Standard deviation (spread or “width”) of the distribution. Must be non-negative.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+    Whether to perform the random generation in parallel along the output dimension.
+    It can be used to ensure that the just-in-time generated :math:`M^T` is the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ events = as_jax(events)
+ if isinstance(w_mu, float): w_mu = as_jax(w_mu)
+ if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma)
+ w_mu = jnp.atleast_1d(as_jax(w_mu))
+ w_sigma = jnp.atleast_1d(as_jax(w_sigma))
+ conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+ conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+ if seed is None:
+ with jax.ensure_compile_time_eval():
+ seed = np.random.randint(0, int(1e8), 1)
+ seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+ return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)[0]
+
+
+# -------------
+# CPU function
+# -------------
+# For each non-zero event, the kernel generates a random key with
+# lfsr88_key and uses it to draw random row strides, accumulating the
+# weight into the corresponding entries of the output array.
+#
+# The outer loop over columns is written so that it can be parallelized.
+
+
+@ti.kernel
+def _event_mv_prob_homo_bool_cpu(
+ events: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_col in range(num_col):
+ if events[i_col]:
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_row < num_row:
+ out[i_row] += weight0
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_homo_outdim_parallel_bool_cpu(
+ events: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ if events[i_col]:
+ r += weight0
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] = r
+
+
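The kernels above never materialize the random matrix: starting from a random offset, they repeatedly jump ahead by a stride drawn from `[1, clen]`, whose mean is about `1/conn_prob`, so roughly a `conn_prob` fraction of rows is visited per event. A plain-Python sketch of this idea, with a NumPy generator standing in for the LFSR88 routines used above:

```python
import numpy as np

def visited_rows(num_row, conn_prob, seed=0):
    """Rows reached from one presynaptic event via random-stride skipping."""
    rng = np.random.default_rng(seed)
    clen = int(np.ceil(1 / conn_prob) * 2 - 1)     # same conn_len as in the wrappers above
    rows = []
    i = int(rng.integers(0, clen))                 # first candidate row, like lfsr88_random_integers(key, 0, clen-1)
    while i < num_row:
        rows.append(i)
        i += int(rng.integers(1, clen + 1))        # stride in [1, clen], mean ~= 1/conn_prob
    return rows

hits = visited_rows(num_row=10_000, conn_prob=0.05)
print(len(hits) / 10_000)                          # close to 0.05 on average
```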
+# -------------
+# GPU function
+# -------------
+# Unlike the CPU kernels, these kernels use 32 threads (one warp) per
+# column (per row in the outdim_parallel variants) to parallelize the
+# just-in-time random generation.
+
+
+@ti.kernel
+def _event_mv_prob_homo_bool_gpu(
+ events: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ if events[i_col]:
+ index = i & 31
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ out[i_row] += weight0
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_homo_outdim_parallel_bool_gpu(
+ events: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.u32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ index = i & 31
+ i_col = step * index - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ r += weight0 * events[i_col] # TODO: speed comparison without if else
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += r # TODO: warp-level reduction
+
+
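In the GPU kernels, one flat loop index encodes both the column (or row, for the outdim_parallel variants) and the warp lane, and each lane handles a contiguous slice of the output; a small sketch of that arithmetic (mirroring `_event_mv_prob_homo_bool_gpu`, illustrative only):

```python
def decode(i, num_row):
    """Decode a flat GPU loop index into (column, lane) and the lane's row slice."""
    i_col = i >> 5                        # 32 consecutive indices share one column
    lane = i & 31                         # lane id within the warp
    step = max((num_row + 1) >> 5, 1)     # rows handled per lane
    start = step * lane                   # the kernel starts at start - 1 and adds a random stride
    end = min(start + step, num_row)
    return i_col, lane, start, end

print(decode(65, num_row=1000))           # -> (2, 1, 31, 62)
```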
+def _reverse(shape):
+ return shape[::-1]
+
+
+# -------------
+# CPU function
+# -------------
+# For each non-zero event, the kernel generates a random key with
+# lfsr88_key and uses it to draw random row strides, accumulating the
+# weight into the corresponding entries of the output array.
+#
+# The outer loop over columns is written so that it can be parallelized.
+
+
+@ti.kernel
+def _event_mv_prob_homo_cpu(
+ events: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_col in range(num_col):
+ if events[i_col] != 0.:
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_row < num_row:
+ out[i_row] += weight0
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_homo_outdim_parallel_cpu(
+ events: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ if events[i_col] != 0.:
+ r += weight0
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] = r # TODO: warp-level reduction
+
+
+# -------------
+# GPU function
+# -------------
+# Unlike the CPU kernels, these kernels use 32 threads (one warp) per
+# column (per row in the outdim_parallel variants) to parallelize the
+# just-in-time random generation.
+
+
+@ti.kernel
+def _event_mv_prob_homo_gpu(
+ events: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ if events[i_col] != 0.:
+ index = i & 31
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ out[i_row] += weight0
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_homo_outdim_parallel_gpu(
+ events: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ index = i & 31
+ i_col = step * index - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ r += weight0 * events[i_col] # TODO: speed comparison with if else
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += r # TODO: warp-level reduction
+
+
+def _event_mv_prob_homo_jvp_events(
+ evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_homo(evt_dot, weight, clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _event_mv_prob_homo_jvp_weight(
+ w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_homo(events, w_dot, clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
+ assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64]
+ return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights)
+
+
+def raw_event_mv_prob_homo(
+ events: jax.Array,
+ weight: jax.Array, # vector with size 1
+ conn_len: jax.Array, # vector with size 1
+ seed: jax.Array, # vector with size 1
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight)
+
+ if outdim_parallel:
+ if events.dtype == jnp.bool_:
+ prim = _event_mv_prob_homo_outdim_parallel_bool_p
+ else:
+ prim = _event_mv_prob_homo_outdim_parallel_p
+ else:
+ if events.dtype == jnp.bool_:
+ prim = _event_mv_prob_homo_bool_p
+ else:
+ prim = _event_mv_prob_homo_p
+
+ return prim(events,
+ weight,
+ conn_len,
+ seed,
+ outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)],
+ shape=mat_shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel):
+ prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
+ prim.defjvp(_event_mv_prob_homo_jvp_events,
+ _event_mv_prob_homo_jvp_weight,
+ None,
+ None)
+ prim.def_transpose_rule(_mv_prob_homo_transpose)
+ return prim
+
+
+# outdim_parallel = True, events.dtype = jnp.bool_
+_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim(
+ cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu,
+ gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu
+)
+
+# outdim_parallel = False, events.dtype = jnp.bool_
+_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim(
+ cpu_kernel=_event_mv_prob_homo_bool_cpu,
+ gpu_kernel=_event_mv_prob_homo_bool_gpu
+)
+
+# outdim_parallel = True, events.dtype != jnp.bool_
+_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim(
+ cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu,
+ gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu
+)
+
+# outdim_parallel = False, events.dtype != jnp.bool_
+_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim(
+ cpu_kernel=_event_mv_prob_homo_cpu,
+ gpu_kernel=_event_mv_prob_homo_gpu
+)
+
+
+@ti.kernel
+def _event_mv_prob_uniform_bool_cpu(
+ events: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_col in range(num_col):
+ if events[i_col]:
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_row < num_row:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ out[i_row] += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_uniform_outdim_parallel_bool_cpu(
+ events: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ if events[i_col]:
+ r += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] = r
+
+
+@ti.kernel
+def _event_mv_prob_uniform_bool_gpu(
+ events: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ if events[i_col]:
+ index = i & 31
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ out[i_row] += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_uniform_outdim_parallel_bool_gpu(
+ events: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+  step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
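+  # outdim-parallel GPU layout: 32 threads share one output row (i >> 5); each
+  # lane (i & 31) scans its own step-sized slice of the columns, weights the
+  # sampled values by the event values, and adds its partial sum to out[i_row].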
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ index = i & 31
+ i_col = step * index - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ r += row_v * events[i_col] # TODO: speed comparison without if else
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += r # TODO: warp-level reduction
+
+
+@ti.kernel
+def _event_mv_prob_uniform_cpu(
+ events: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_col in range(num_col):
+ if events[i_col] != 0.:
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_row < num_row:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ out[i_row] += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_uniform_outdim_parallel_cpu(
+ events: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ if events[i_col] != 0.:
+ r += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+    out[i_row] = r
+
+
+@ti.kernel
+def _event_mv_prob_uniform_gpu(
+ events: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ if events[i_col] != 0.:
+ index = i & 31
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ out[i_row] += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_uniform_outdim_parallel_gpu(
+ events: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ index = i & 31
+ i_col = step * index - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ r += row_v * events[i_col] # TODO: speed comparison with if else
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += r # TODO: warp-level reduction
+
+
+def _event_mv_prob_uniform_jvp_events(
+ evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _event_mv_prob_uniform_jvp_w_low(
+ w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _event_mv_prob_uniform_jvp_w_high(
+ w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def raw_event_mv_prob_uniform(
+ events: jax.Array,
+ w_low: jax.Array, # vector with size 1
+ w_high: jax.Array, # vector with size 1
+ conn_len: jax.Array, # vector with size 1
+ seed: jax.Array, # vector with size 1
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high)
+
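+  # Pick the primitive specialized for the traversal order (outdim_parallel) and
+  # for whether the events are boolean spikes or float-valued activations.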
+ if outdim_parallel:
+ if events.dtype == jnp.bool_:
+ prim = _event_mv_prob_uniform_outdim_parallel_bool_p
+ else:
+ prim = _event_mv_prob_uniform_outdim_parallel_p
+ else:
+ if events.dtype == jnp.bool_:
+ prim = _event_mv_prob_uniform_bool_p
+ else:
+ prim = _event_mv_prob_uniform_p
+
+ return prim(events,
+ w_low,
+ w_high,
+ conn_len,
+ seed,
+ outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)],
+ shape=mat_shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel):
+ prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
+ prim.defjvp(_event_mv_prob_uniform_jvp_events,
+ _event_mv_prob_uniform_jvp_w_low,
+ _event_mv_prob_uniform_jvp_w_high,
+ None,
+ None)
+ prim.def_transpose_rule(_mv_prob_uniform_transpose)
+ return prim
+
+
+# outdim_parallel = True, events.dtype = jnp.bool_
+_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim(
+ cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu,
+ gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu
+)
+
+# outdim_parallel = False, events.dtype = jnp.bool_
+_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim(
+ cpu_kernel=_event_mv_prob_uniform_bool_cpu,
+ gpu_kernel=_event_mv_prob_uniform_bool_gpu
+)
+
+# outdim_parallel = True, events.dtype != jnp.bool_
+_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim(
+ cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu,
+ gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu
+)
+
+# outdim_parallel = False, events.dtype != jnp.bool_
+_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim(
+ cpu_kernel=_event_mv_prob_uniform_cpu,
+ gpu_kernel=_event_mv_prob_uniform_gpu
+)
+
+
+@ti.kernel
+def _event_mv_prob_normal_bool_cpu(
+ events: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_col in range(num_col):
+ if events[i_col]:
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_row < num_row:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ out[i_row] += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_normal_outdim_parallel_bool_cpu(
+ events: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ if events[i_col]:
+ r += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] = r
+
+
+@ti.kernel
+def _event_mv_prob_normal_bool_gpu(
+ events: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ if events[i_col]:
+ index = i & 31
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ out[i_row] += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_normal_outdim_parallel_bool_gpu(
+ events: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+  step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ index = i & 31
+ i_col = step * index - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ r += row_v * events[i_col] # TODO: speed comparison without if else
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += r # TODO: warp-level reduction
+
+
+@ti.kernel
+def _event_mv_prob_normal_cpu(
+ events: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_col in range(num_col):
+ if events[i_col] != 0.:
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_row < num_row:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ out[i_row] += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_normal_outdim_parallel_cpu(
+ events: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ if events[i_col] != 0.:
+ r += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] = r
+
+
+@ti.kernel
+def _event_mv_prob_normal_gpu(
+ events: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ if events[i_col] != 0.:
+ index = i & 31
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ out[i_row] += row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _event_mv_prob_normal_outdim_parallel_gpu(
+ events: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = events.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ index = i & 31
+ i_col = step * index - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ r += row_v * events[i_col] # TODO: speed comparison with if else
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += r # TODO: warp-level reduction
+
+
+def _event_mv_prob_normal_jvp_events(
+ evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _event_mv_prob_normal_jvp_w_mu(
+ w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _event_mv_prob_normal_jvp_w_sigma(
+ w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def raw_event_mv_prob_normal(
+ events: jax.Array,
+ w_mu: jax.Array, # vector with size 1
+ w_sigma: jax.Array, # vector with size 1
+ conn_len: jax.Array, # vector with size 1
+ seed: jax.Array, # vector with size 1
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma)
+
+ if outdim_parallel:
+ if events.dtype == jnp.bool_:
+ prim = _event_mv_prob_normal_outdim_parallel_bool_p
+ else:
+ prim = _event_mv_prob_normal_outdim_parallel_p
+ else:
+ if events.dtype == jnp.bool_:
+ prim = _event_mv_prob_normal_bool_p
+ else:
+ prim = _event_mv_prob_normal_p
+
+ return prim(events,
+ w_mu,
+ w_sigma,
+ conn_len,
+ seed,
+ outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)],
+ shape=mat_shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel):
+ prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
+ prim.defjvp(_event_mv_prob_normal_jvp_events,
+ _event_mv_prob_normal_jvp_w_mu,
+ _event_mv_prob_normal_jvp_w_sigma,
+ None,
+ None)
+ prim.def_transpose_rule(_mv_prob_normal_transpose)
+ return prim
+
+
+# outdim_parallel = True, events.dtype = jnp.bool_
+_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim(
+ cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu,
+ gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu
+)
+
+# outdim_parallel = False, events.dtype = jnp.bool_
+_event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim(
+ cpu_kernel=_event_mv_prob_normal_bool_cpu,
+ gpu_kernel=_event_mv_prob_normal_bool_gpu
+)
+
+# outdim_parallel = True, events.dtype != jnp.bool_
+_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim(
+ cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu,
+ gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu
+)
+
+# outdim_parallel = False, events.dtype != jnp.bool_
+_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim(
+ cpu_kernel=_event_mv_prob_normal_cpu,
+ gpu_kernel=_event_mv_prob_normal_gpu
+)
diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py
index cad95924d..0caa9c996 100644
--- a/brainpy/_src/math/jitconn/_matvec.py
+++ b/brainpy/_src/math/jitconn/_matvec.py
@@ -11,12 +11,15 @@
from jax.interpreters import xla, ad
from jax.lib import xla_client
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops
+from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array, _get_dtype
-from brainpy._src.math.op_register import register_general_batching
+from brainpy._src.math.op_register import register_general_batching, XLACustomOp
+from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal)
from brainpy.errors import GPUOperatorNotFound
+ti = import_taichi()
+
__all__ = [
'mv_prob_homo',
'mv_prob_uniform',
@@ -49,6 +52,200 @@ def mv_prob_homo(
When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, you should set ``outdim_parallel=True``, at the cost of lower
+    speed compared with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ weight: float
+    The scalar weight used at every connection of the random matrix.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` (when
+    ``transpose=True``) is identical to the just-in-time generated :math:`M`
+    (when ``transpose=False``).
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
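+
+  Examples
+  --------
+  A minimal usage sketch; the vector size, matrix shape, connection probability
+  and seed below are illustrative values, and ``brainpy.math`` is assumed to be
+  imported as ``bm``:
+
+  >>> import jax.numpy as jnp
+  >>> import brainpy.math as bm
+  >>> vector = jnp.ones(200)
+  >>> out = bm.jitconn.mv_prob_homo(vector, weight=1.5, conn_prob=0.1, seed=42,
+  ...                               shape=(100, 200))
+  >>> out.shape
+  (100,)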
+ """
+ return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+def mv_prob_uniform(
+ vector: jax.Array,
+ w_low: float,
+ w_high: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a uniform distribution for its value.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+  In this operation, :math:`M` is the random matrix with a connection probability
+  `conn_prob`, and the value at each connection is drawn from a uniform
+  distribution between ``w_low`` and ``w_high``.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, you should set ``outdim_parallel=True``, at the cost of lower
+    speed compared with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ w_low: float
+ Lower boundary of the output interval.
+ w_high: float
+ Upper boundary of the output interval.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` (when
+    ``transpose=True``) is identical to the just-in-time generated :math:`M`
+    (when ``transpose=False``).
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
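+
+  Examples
+  --------
+  A minimal usage sketch with illustrative values (the ``brainpy.math.jitconn``
+  alias ``bm.jitconn`` is assumed):
+
+  >>> import jax.numpy as jnp
+  >>> import brainpy.math as bm
+  >>> vector = jnp.ones(200)
+  >>> out = bm.jitconn.mv_prob_uniform(vector, w_low=0., w_high=1., conn_prob=0.1,
+  ...                                  seed=42, shape=(100, 200))
+  >>> out.shape
+  (100,)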
+ """
+ return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+def mv_prob_normal(
+ vector: jax.Array,
+ w_mu: float,
+ w_sigma: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a normal distribution for its value.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+  In this operation, :math:`M` is the random matrix with a connection probability
+  `conn_prob`, and the value at each connection is drawn from a normal
+  distribution with mean ``w_mu`` and standard deviation ``w_sigma``.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, you should set ``outdim_parallel=True``, at the cost of lower
+    speed compared with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ w_mu: float
+ Mean (centre) of the distribution.
+ w_sigma: float
+ Standard deviation (spread or “width”) of the distribution. Must be non-negative.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` (when
+    ``transpose=True``) is identical to the just-in-time generated :math:`M`
+    (when ``transpose=False``).
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
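+
+  Examples
+  --------
+  A minimal usage sketch with illustrative values (``bm`` is assumed to be
+  ``brainpy.math``):
+
+  >>> import jax.numpy as jnp
+  >>> import brainpy.math as bm
+  >>> vector = jnp.ones(200)
+  >>> out = bm.jitconn.mv_prob_normal(vector, w_mu=0., w_sigma=0.1, conn_prob=0.1,
+  ...                                 seed=42, shape=(100, 200))
+  >>> out.shape
+  (100,)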
+ """
+  return mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose,
+                               outdim_parallel=outdim_parallel)
+
+
+### BRAINPYLIB ###
+
+def mv_prob_homo_brainpylib(
+ vector: Union[Array, jax.Array],
+ weight: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+ In this operation, :math:`M` is the random matrix with a connection probability
+ `conn_prob`, and at each connection the value is the same scalar `weight`.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
.. note::
Note that the just-in-time generated :math:`M` (`transpose=False`) is
@@ -100,7 +297,7 @@ def mv_prob_homo(
)[0]
-def mv_prob_uniform(
+def mv_prob_uniform_brainpylib(
vector: jax.Array,
w_low: float,
w_high: float,
@@ -180,7 +377,7 @@ def mv_prob_uniform(
outdim_parallel=outdim_parallel)[0]
-def mv_prob_normal(
+def mv_prob_normal_brainpylib(
vector: jax.Array,
w_mu: float,
w_sigma: float,
@@ -443,8 +640,8 @@ def _matvec_prob_homo_transpose(
mv_prob_homo_p.multiple_results = True
mv_prob_homo_p.def_abstract_eval(_matvec_prob_homo_abstract)
mv_prob_homo_p.def_impl(partial(xla.apply_primitive, mv_prob_homo_p))
-xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation
-xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation
+# xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation
+# xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation
register_general_batching(mv_prob_homo_p)
ad.primitive_jvps[mv_prob_homo_p] = _matvec_prob_homo_jvp
ad.primitive_transposes[mv_prob_homo_p] = _matvec_prob_homo_transpose
@@ -626,8 +823,8 @@ def _matvec_prob_uniform_transpose(
mv_prob_uniform_p.multiple_results = True
mv_prob_uniform_p.def_abstract_eval(_matvec_prob_uniform_abstract)
mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, mv_prob_uniform_p))
-xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation
-xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation
+# xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation
+# xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation
register_general_batching(mv_prob_uniform_p)
ad.primitive_jvps[mv_prob_uniform_p] = _matvec_prob_uniform_jvp
ad.primitive_transposes[mv_prob_uniform_p] = _matvec_prob_uniform_transpose
@@ -812,8 +1009,897 @@ def _matvec_prob_normal_transpose(
mv_prob_normal_p.multiple_results = True
mv_prob_normal_p.def_abstract_eval(_matvec_prob_normal_abstract)
mv_prob_normal_p.def_impl(partial(xla.apply_primitive, mv_prob_normal_p))
-xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation
-xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation
+# xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation
+# xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation
register_general_batching(mv_prob_normal_p)
ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp
ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose
+
+
+### TAICHI ###
+def mv_prob_homo_taichi(
+ vector: Union[Array, jax.Array],
+ weight: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+ In this operation, :math:`M` is the random matrix with a connection probability
+ `conn_prob`, and at each connection the value is the same scalar `weight`.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, you should set ``outdim_parallel=True``, at the cost of lower
+    speed compared with ``outdim_parallel=False``.
+
+    Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same
+    as the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``.
+
+    Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same
+    as the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ weight: float
+    The scalar weight used at every connection of the random matrix.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` (when
+    ``transpose=True``) is identical to the just-in-time generated :math:`M`
+    (when ``transpose=False``).
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ vector = as_jax(vector)
+ if isinstance(weight, float):
+ weight = as_jax(weight, dtype=vector.dtype)
+ weight = jnp.atleast_1d(as_jax(weight))
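+  # Connections are generated by skipping a random number of indices drawn from
+  # [1, conn_len]; with conn_len = ceil(1 / conn_prob) * 2 - 1 the mean skip is
+  # roughly 1 / conn_prob, which reproduces the requested connection probability.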
+ conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+ clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+ if seed is None:
+ with jax.ensure_compile_time_eval():
+ seed = np.random.randint(0, int(1e8), 1)
+ seed = jnp.asarray(seed, dtype=jnp.uint32)
+ seed = jnp.atleast_1d(seed)
+ return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)[0]
+
+
+def mv_prob_uniform_taichi(
+ vector: jax.Array,
+ w_low: float,
+ w_high: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a uniform distribution for its value.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+  In this operation, :math:`M` is the random matrix with a connection probability
+  `conn_prob`, and the value at each connection is drawn from a uniform
+  distribution between ``w_low`` and ``w_high``.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, you should set ``outdim_parallel=True``, at the cost of lower
+    speed compared with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ w_low: float
+ Lower boundary of the output interval.
+ w_high: float
+ Upper boundary of the output interval.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` (when
+    ``transpose=True``) is identical to the just-in-time generated :math:`M`
+    (when ``transpose=False``).
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ vector = as_jax(vector)
+ if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype)
+ if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype)
+ w_low = jnp.atleast_1d(as_jax(w_low))
+ w_high = jnp.atleast_1d(as_jax(w_high))
+ conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+ conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+ if seed is None:
+ with jax.ensure_compile_time_eval():
+ seed = np.random.randint(0, int(1e8), 1)
+ seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+ return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)[0]
+
+
+def mv_prob_normal_taichi(
+ vector: jax.Array,
+ w_mu: float,
+ w_sigma: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a normal distribution for its value.
+
+  This operator supports ``jit()``, ``vmap()``, ``grad()``, ``pmap()`` and other
+  transformations on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+  In this operation, :math:`M` is the random matrix with a connection probability
+  `conn_prob`, and the value at each connection is drawn from a normal
+  distribution with mean ``w_mu`` and standard deviation ``w_sigma``.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, you should set ``outdim_parallel=True``, at the cost of lower
+    speed compared with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ w_mu: float
+ Mean (centre) of the distribution.
+ w_sigma: float
+ Standard deviation (spread or “width”) of the distribution. Must be non-negative.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` (when
+    ``transpose=True``) is identical to the just-in-time generated :math:`M`
+    (when ``transpose=False``).
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ vector = as_jax(vector)
+ if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype)
+ if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype)
+ w_mu = jnp.atleast_1d(as_jax(w_mu))
+ w_sigma = jnp.atleast_1d(as_jax(w_sigma))
+ conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+ conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+ if seed is None:
+ with jax.ensure_compile_time_eval():
+ seed = np.random.randint(0, int(1e8), 1)
+ seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+ return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)[0]
+
+
+def _reverse(shape):
+ return shape[::-1]
+
+
+@ti.kernel
+def _mv_prob_homo_cpu(
+ vector: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
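+  # Non-event variant: every input column contributes. The connected output rows
+  # of a column are enumerated by starting at a random offset and repeatedly
+  # jumping a random gap of at most clen0 rows; each visited row accumulates
+  # vector[i_col] * weight0.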
+ for i_col in range(num_col):
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ v = vector[i_col] * weight0
+ while i_row < num_row:
+ out[i_row] += v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _mv_prob_homo_outdim_parallel_cpu(
+ vector: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ r += vector[i_col]
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] = r * weight0
+
+
+@ti.kernel
+def _mv_prob_homo_gpu(
+ vector: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ index = i & 31
+ col_v = vector[i_col]
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ out[i_row] += weight0 * col_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _mv_prob_homo_outdim_parallel_gpu(
+ vector: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ weight0 = weight[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+  step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ i_thread = i & 31
+ i_col = step * i_thread - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ r += vector[i_col]
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += weight0 * r # TODO: warp-level reduction
+
+
+def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _mv_prob_homo_transpose(
+ ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ if ad.is_undefined_primal(vector):
+ if type(ct) is ad.Zero:
+ return ad.Zero(vector), weight, clen, seed
+ else:
+ dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape,
+ transpose=not transpose, outdim_parallel=not outdim_parallel)[0]
+ return dv, weight, clen, seed
+ elif ad.is_undefined_primal(weight):
+ if type(ct) is ad.Zero:
+ return vector, ad.Zero(weight), clen, seed
+ else:
+ row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed,
+ shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0]
+ dw = jnp.sum(row * vector, keepdims=True)
+ return vector, dw, clen, seed
+ else:
+ assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.'
+ assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.'
+
+
+def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
+ if vector.ndim != 1:
+ raise ValueError('vector should be a 1D vector.')
+ if len(shape) != 2:
+ raise ValueError('shape should be a length-2 tuple.')
+ if seed.ndim != 1:
+ raise ValueError('seed must be a 1D scalar.')
+ if clen.ndim != 1:
+ raise ValueError('conn_prob must be a 1D scalar.')
+
+ assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64]
+ assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64]
+
+ for weight in weights:
+ if weight.ndim != 1:
+ raise ValueError('weight must be a 1D scalar.')
+ assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.'
+
+ if not isinstance(outdim_parallel, bool):
+ raise ValueError('outdim_parallel must be boolean value.')
+ if not isinstance(transpose, bool):
+ raise ValueError('transpose must be boolean value.')
+
+ if transpose:
+ out_shape = (shape[1],)
+ if vector.shape[0] != shape[0]:
+ raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.')
+ shape = _reverse(shape)
+ else:
+ if vector.shape[0] != shape[1]:
+ raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).')
+ out_shape = (shape[0],)
+
+ return shape, out_shape
+
+
+def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
+ assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64]
+ return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights)
+
+
+def raw_mv_prob_homo(
+ vector: jax.Array,
+ weight: jax.Array, # vector with size 1
+ clen: jax.Array, # vector with size 1
+ seed: jax.Array, # vector with size 1
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight)
+
+ if outdim_parallel:
+ prim = _mv_prob_homo_outdim_parallel_p
+ else:
+ prim = _mv_prob_homo_p
+
+ return prim(vector,
+ weight,
+ clen,
+ seed,
+ outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)],
+ shape=mat_shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel):
+ prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
+ prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None)
+ prim.def_transpose_rule(_mv_prob_homo_transpose)
+ return prim
+
+
+# outdim_parallel = True
+_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu,
+ gpu_kernel=_mv_prob_homo_outdim_parallel_gpu)
+
+# outdim_parallel = False
+_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu,
+ gpu_kernel=_mv_prob_homo_gpu)
+
+
+@ti.kernel
+def _mv_prob_uniform_cpu(
+ vector: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_col in range(num_col):
+ col_v = vector[i_col]
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_row < num_row:
+ key, raw_v = lfsr88_uniform(key, w_min0, w_max0)
+ out[i_row] += col_v * raw_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _mv_prob_uniform_outdim_parallel_cpu(
+ vector: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ key, raw_v = lfsr88_uniform(key, w_min0, w_max0)
+ r += vector[i_col] * raw_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] = r
+
+
+@ti.kernel
+def _mv_prob_uniform_gpu(
+ vector: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ index = i & 31
+ col_v = vector[i_col]
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ out[i_row] += row_v * col_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _mv_prob_uniform_outdim_parallel_gpu(
+ vector: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+  step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ i_thread = i & 31
+ i_col = step * i_thread - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ key, row_v = lfsr88_uniform(key, w_min0, w_max0)
+ r += vector[i_col] * row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += r # TODO: warp-level reduction
+
+
+def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *,
+ outs, shape, transpose, outdim_parallel):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *,
+ outs, shape, transpose, outdim_parallel):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *,
+ outs, shape, transpose, outdim_parallel):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _mv_prob_uniform_transpose(
+ ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ if ad.is_undefined_primal(vector):
+ if type(ct) is ad.Zero:
+ return ad.Zero(vector), w_low, w_high, clen, seed
+ else:
+ dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape,
+ transpose=not transpose, outdim_parallel=not outdim_parallel)[0]
+ return dv, w_low, w_high, clen, seed
+ else:
+ assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.'
+ assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.'
+ assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.'
+ assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.'
+
+
+def raw_mv_prob_uniform(
+ vector: jax.Array,
+ w_low: jax.Array,
+ w_high: jax.Array,
+ conn_len: jax.Array,
+ seed: jax.Array,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high)
+
+ if outdim_parallel:
+ prim = _mv_prob_uniform_outdim_parallel_p
+ else:
+ prim = _mv_prob_uniform_p
+
+ return prim(vector,
+ w_low,
+ w_high,
+ conn_len,
+ seed,
+ outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)],
+ shape=mat_shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel):
+ prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
+ prim.defjvp(_mv_prob_uniform_jvp_vector,
+ _mv_prob_uniform_jvp_wlow,
+ _mv_prob_uniform_jvp_whigh,
+ None,
+ None)
+ prim.def_transpose_rule(_mv_prob_uniform_transpose)
+ return prim
+
+
+# outdim_parallel = True
+_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim(
+ cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu,
+ gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu
+)
+
+# outdim_parallel = False
+_mv_prob_uniform_p = _define_mv_prob_uniform_prim(
+ cpu_kernel=_mv_prob_uniform_cpu,
+ gpu_kernel=_mv_prob_uniform_gpu
+)
+
+
+@ti.kernel
+def _mv_prob_normal_cpu(
+ vector: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_col in range(num_col):
+ col_v = vector[i_col]
+ key = lfsr88_key(seed0 + i_col)
+ key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_row < num_row:
+ key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ out[i_row] += col_v * raw_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _mv_prob_normal_outdim_parallel_cpu(
+ vector: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+
+ for i_row in range(num_row):
+ r = 0.
+ key = lfsr88_key(seed0 + i_row)
+ key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+ while i_col < num_col:
+ key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ r += vector[i_col] * raw_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] = r
+
+
+@ti.kernel
+def _mv_prob_normal_gpu(
+ vector: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+ step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_col * 32):
+ i_col = i >> 5
+ index = i & 31
+ col_v = vector[i_col]
+ i_row = step * index - 1
+ end = ti.min(i_row + step, num_row)
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+ while i_row < end:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ out[i_row] += row_v * col_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_row += inc
+
+
+@ti.kernel
+def _mv_prob_normal_outdim_parallel_gpu(
+ vector: ti.types.ndarray(ndim=1),
+ w_mu: ti.types.ndarray(ndim=1),
+ w_sigma: ti.types.ndarray(ndim=1),
+ clen: ti.types.ndarray(ndim=1),
+ seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)
+):
+ num_row = out.shape[0]
+ num_col = vector.shape[0]
+ w_mu0 = w_mu[0]
+ w_sigma0 = w_sigma[0]
+ clen0 = clen[0]
+ seed0 = seed[0]
+  step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
+
+ for i in range(num_row * 32):
+ i_row = i >> 5
+ i_thread = i & 31
+ i_col = step * i_thread - 1
+ end_col = ti.min(i_col + step, num_col)
+ r = 0.
+ key = lfsr88_key(seed0 + i)
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ while i_col < end_col:
+ key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
+ r += vector[i_col] * row_v
+ key, inc = lfsr88_random_integers(key, 1, clen0)
+ i_col += inc
+ out[i_row] += r # TODO: warp-level reduction
+
+
+def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel):
+ shape = _reverse(shape) if transpose else shape
+ return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
+
+
+def _mv_prob_normal_transpose(
+ ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel
+):
+ shape = _reverse(shape) if transpose else shape
+ if ad.is_undefined_primal(vector):
+ if type(ct) is ad.Zero:
+ return ad.Zero(vector), w_mu, w_sigma, clen, seed
+ else:
+ dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape,
+ transpose=not transpose, outdim_parallel=not outdim_parallel)[0]
+ return dv, w_mu, w_sigma, clen, seed
+ else:
+ assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.'
+ assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.'
+ assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.'
+ assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.'
+
+
+def raw_mv_prob_normal(
+ vector: jax.Array,
+ w_mu: jax.Array,
+ w_sigma: jax.Array,
+ conn_len: jax.Array,
+ seed: jax.Array,
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ outdim_parallel: bool = True,
+) -> jax.Array:
+ mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma)
+
+ if outdim_parallel:
+ prim = _mv_prob_normal_outdim_parallel_p
+ else:
+ prim = _mv_prob_normal_p
+
+ return prim(vector,
+ w_mu,
+ w_sigma,
+ conn_len,
+ seed,
+ outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)],
+ shape=mat_shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel)
+
+
+def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel):
+ prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
+ prim.defjvp(_mv_prob_normal_jvp_vector,
+ _mv_prob_normal_jvp_w_mu,
+ _mv_prob_normal_jvp_w_sigma,
+ None,
+ None)
+ prim.def_transpose_rule(_mv_prob_normal_transpose)
+ return prim
+
+
+# outdim_parallel = True
+_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim(
+ cpu_kernel=_mv_prob_normal_outdim_parallel_cpu,
+ gpu_kernel=_mv_prob_normal_outdim_parallel_gpu
+)
+
+# outdim_parallel = False
+_mv_prob_normal_p = _define_mv_prob_normal_prim(
+ cpu_kernel=_mv_prob_normal_cpu,
+ gpu_kernel=_mv_prob_normal_gpu
+)
diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py
new file mode 100644
index 000000000..21a246650
--- /dev/null
+++ b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py
@@ -0,0 +1,573 @@
+# from jax_taichi import jax_taichi_call
+
+import time
+from functools import partial
+import os
+
+import brainpy as bp
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import taichi as ti
+
+bm.set_platform('cpu')
+
+seed = 1234
+
+shape = [
+ 1000,
+ 2500,
+ 5000,
+ 10000,
+ 25000,
+ 37500,
+ 50000
+ ]
+types = [
+ 'homo',
+ 'uniform',
+ 'normal'
+ ]
+transpose = [
+ True,
+ False
+ ]
+outdim_parallel = [
+ True,
+ False,
+ ]
+bool_event = [
+ True,
+ False
+ ]
+conn_prob = 0.05
+homo_data = 1.
+w_low = 0.
+w_high = 1.
+w_mu = 0.
+w_sigma = 0.1
+
+ITERATION = 100
+if bm.get_platform() == 'cpu':
+ ITERATION = 10
+
+print(bm.get_platform())
+
+@partial(jax.jit, static_argnums=(4, 5, 6))
+def jitconn_event_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.event_mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0]
+ return r
+
+@partial(jax.jit, static_argnums=(4, 5, 6))
+def jitconn_event_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.event_mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0]
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_event_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.event_mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_event_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.event_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_event_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.event_mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_event_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.event_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+
+def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ if not bool_event:
+ events = events.astype(float)
+
+ # groundtruth = bm.as_jax(events, dtype=float) @ bm.as_jax(dense)
+
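+  # warm up the taichi kernel (triggers compilation) before the timed runs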
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
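+  # warm up and time the brainpylib implementation with the same protocol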
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
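+  # convert the measured wall-clock intervals to milliseconds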
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ if not bool_event:
+ events = events.astype(float)
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
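+  # same protocol as the homo case: five warm-up calls, then ten timed runs per backend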
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ if not bool_event:
+ events = events.astype(float)
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
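+  # same protocol as above, with normally distributed weights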
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+
+def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event):
+ print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
+ if _type == 'homo':
+ return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event)
+ elif _type == 'uniform':
+ return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event)
+ elif _type == 'normal':
+ return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event)
+ else:
+    raise ValueError(f'Unknown type: {_type}')
+
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+# init dataframe
+df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event',
+ 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
+ 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
+ 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+ 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
+
+### RECTANGULAR MATRIX
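+# Sweep all combinations of matrix shape, weight type, transpose flag,
+# outdim_parallel flag and event dtype; each case appends one row of timings.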
+if (bm.get_platform() == 'cpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _type in types:
+ for _outdim_parallel in outdim_parallel:
+ for _transpose in transpose:
+ for _bool_event in bool_event:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event)
+ # append to dataframe
+ df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/jitconn_event_matvec_cpu.csv', index=False)
+
+if (bm.get_platform() == 'gpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _type in types:
+ for _outdim_parallel in outdim_parallel:
+ for _transpose in transpose:
+ for _bool_event in bool_event:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event)
+ # append to dataframe
+ df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/jitconn_event_matvec_gpu.csv', index=False)
diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py
new file mode 100644
index 000000000..ff4f01afc
--- /dev/null
+++ b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py
@@ -0,0 +1,589 @@
+# from jax_taichi import jax_taichi_call
+
+import time
+from functools import partial
+import os
+
+import brainpy as bp
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import taichi as ti
+
+bm.set_platform('cpu')
+# bm.disable_gpu_memory_preallocation()
+
+seed = 1234
+
+shape = [
+ 1000,
+ 2500,
+ 5000,
+ 10000,
+ 25000,
+ 37500,
+ 50000
+ ]
+types = [
+ 'homo',
+ 'uniform',
+ 'normal'
+ ]
+transpose = [
+ True,
+ False
+ ]
+outdim_parallel = [
+ True,
+ False,
+ ]
+bool_event = [
+ True,
+ False
+ ]
+conn_prob = 0.05
+homo_data = 1.
+w_low = 0.
+w_high = 1.
+w_mu = 0.
+w_sigma = 0.1
+
+print(bm.get_platform())
+
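+# Reduce the operator's first output to a scalar so jax.grad can differentiate
+# it in the benchmark wrappers below.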
+def sum_op(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)[0]
+ return r.sum()
+
+ return func
+
+ITERATION = 100
+if bm.get_platform() == 'cpu':
+ ITERATION = 10
+
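+# Jitted wrappers accumulating ITERATION gradient evaluations per call, with
+# shape, transpose and outdim_parallel kept static.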
+@partial(jax.jit, static_argnums=(4, 5, 6))
+def jitconn_event_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+    r += jax.grad(sum_op(bm.jitconn.event_mv_prob_homo_taichi), argnums=0)(
+ vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(4, 5, 6))
+def jitconn_event_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.event_mv_prob_homo), argnums=0)(
+ vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_event_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform_taichi), argnums=0)(
+ vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_event_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform), argnums=0)(
+ vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_event_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal_taichi), argnums=0)(
+ vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_event_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal), argnums=0)(
+ vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ if not bool_event:
+ events = events.astype(float)
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
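+  # warm up the taichi gradient kernel before the timed runs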
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ if not bool_event:
+ events = events.astype(float)
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ if not bool_event:
+ events = events.astype(float)
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event):
+ print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
+ if _type == 'homo':
+ return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event)
+ elif _type == 'uniform':
+ return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event)
+ elif _type == 'normal':
+ return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event)
+ else:
+    raise ValueError(f'Unknown type: {_type}')
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+# init dataframe
+df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event',
+ 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
+ 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
+ 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+ 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
+
+
+### RECTANGULAR MATRIX
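+# Sweep every (shape1, shape2, type, outdim_parallel, transpose, bool_event) combination,
+# append the per-run timings to the dataframe, and write one CSV per platform.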
+if (bm.get_platform() == 'cpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _type in types:
+ for _outdim_parallel in outdim_parallel:
+ for _transpose in transpose:
+ for _bool_event in bool_event:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event)
+ # append to dataframe
+ df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/jitconn_event_matvec_grad_cpu.csv', index=False)
+
+if (bm.get_platform() == 'gpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _type in types:
+ for _outdim_parallel in outdim_parallel:
+ for _transpose in transpose:
+ for _bool_event in bool_event:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event)
+ # append to dataframe
+ df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/jitconn_event_matvec_grad_gpu.csv', index=False)
diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py
new file mode 100644
index 000000000..14a19aefb
--- /dev/null
+++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py
@@ -0,0 +1,560 @@
+# from jax_taichi import jax_taichi_call
+
+import time
+from functools import partial
+import os
+
+import brainpy as bp
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import taichi as ti
+
+bm.set_platform('gpu')
+
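+# Benchmark configuration: matrix sizes, connection types, and kernel flags swept in the loops below.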
+seed = 1234
+
+shape = [
+ 1000,
+ 2500,
+ 5000,
+ 10000,
+ 25000,
+ 37500,
+ 50000
+ ]
+types = [
+ 'homo',
+ 'uniform',
+ 'normal'
+ ]
+transpose = [
+ True,
+ False
+ ]
+outdim_parallel = [
+ True,
+ False,
+ ]
+bool_event = False
+conn_prob = 0.05
+homo_data = 1.
+w_low = 0.
+w_high = 1.
+w_mu = 0.
+w_sigma = 0.1
+
+ITERATION = 100
+if bm.get_platform() == 'cpu':
+ ITERATION = 10
+
+print(bm.get_platform())
+
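+# Each benchmark function accumulates ITERATION kernel calls inside a single jit-compiled
+# function; shape, transpose and outdim_parallel are static arguments, so every
+# configuration is traced and compiled only once.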
+@partial(jax.jit, static_argnums=(4, 5, 6))
+def jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+@partial(jax.jit, static_argnums=(4, 5, 6))
+def jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.jitconn.mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+ return r
+
+def test_jitconn_matvec_homo(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
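+  # Warm-up: the first five calls trigger compilation and are excluded from the timings.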
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
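+  # Repeat the warm-up and the ten timed runs for the current brainpylib implementation.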
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+  result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+def test_jitconn_matvec_normal(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+def test_jitconn_matvec(shape, _type, transpose, outdim_parallel):
+ print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
+ if _type == 'homo':
+ return test_jitconn_matvec_homo(shape, transpose, outdim_parallel)
+ elif _type == 'uniform':
+ return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel)
+ elif _type == 'normal':
+ return test_jitconn_matvec_normal(shape, transpose, outdim_parallel)
+ else:
+    raise ValueError(f'Unknown type: {_type}')
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+# init dataframe
+df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event',
+ 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
+ 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
+ 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+ 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
+
+### RECTANGULAR MATRIX
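+# Sweep every (shape1, shape2, type, outdim_parallel, transpose) combination,
+# append the per-run timings to the dataframe, and write one CSV per platform.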
+if (bm.get_platform() == 'cpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _type in types:
+ for _outdim_parallel in outdim_parallel:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel)
+ # append to dataframe
+ df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, bool_event,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/jitconn_matvec_cpu.csv', index=False)
+
+if (bm.get_platform() == 'gpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _type in types:
+ for _outdim_parallel in outdim_parallel:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel)
+ # append to dataframe
+                df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, bool_event,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/jitconn_matvec_gpu.csv', index=False)
diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py
new file mode 100644
index 000000000..165c9b19b
--- /dev/null
+++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py
@@ -0,0 +1,736 @@
+# from jax_taichi import jax_taichi_call
+
+import time
+from functools import partial
+import os
+
+import brainpy as bp
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import taichi as ti
+
+bm.set_platform('cpu')
+
+seed = 1234
+
+shape = [
+ 1000,
+ 2500,
+ 5000,
+ 10000,
+ 25000,
+ 37500,
+ 50000
+ ]
+bool_event = False
+types = [
+ 'homo',
+ 'uniform',
+ 'normal'
+ ]
+transpose = [
+ True,
+ False
+ ]
+outdim_parallel = [
+ True,
+ False,
+ ]
+conn_prob = 0.05
+homo_data = 1.
+w_low = 0.
+w_high = 1.
+w_mu = 0.
+w_sigma = 0.1
+
+ITERATION = 100
+if bm.get_platform() == 'cpu':
+ ITERATION = 10
+
+print(bm.get_platform())
+
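+# sum_op reduces a kernel's first output to a scalar so that jax.grad can be applied;
+# the *_grad functions below differentiate this sum with respect to the input vector.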
+def sum_op(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)[0]
+ return r.sum()
+
+ return func
+
+@partial(jax.jit, static_argnums=(4, 5, 6))
+def jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)(
+ vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(4, 5, 6))
+def jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)(
+ vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)(
+ vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)(
+ vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)(
+ vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def jitconn_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)(
+ vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ return r
+
+def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
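+  # Warm-up call to trigger compilation before the timed runs.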
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ # time.sleep(2)
+
+ time0 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+ # time.sleep(2)
+
+ time2 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+ # time.sleep(2)
+
+ time4 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+ # time.sleep(2)
+
+ time6 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+
+ time12 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+ # time.sleep(2)
+
+ time14 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+ # time.sleep(2)
+
+ time16 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+ # time.sleep(2)
+
+ time18 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+ time20 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ brainpy_time1 = (time13 - time12) * 1000
+ brainpy_time2 = (time15 - time14) * 1000
+ brainpy_time3 = (time17 - time16) * 1000
+ brainpy_time4 = (time19 - time18) * 1000
+ brainpy_time5 = (time21 - time20) * 1000
+
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_2: ', taichi_aot_time2, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_4: ', taichi_aot_time4, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('brainpylib_cpu_1: ', brainpy_time1, 'ms')
+ print('brainpylib_cpu_2: ', brainpy_time2, 'ms')
+ print('brainpylib_cpu_3: ', brainpy_time3, 'ms')
+ print('brainpylib_cpu_4: ', brainpy_time4, 'ms')
+ print('brainpylib_cpu_5: ', brainpy_time5, 'ms')
+ # assert(jnp.allclose(result1[0], result2))
+
+ speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
+ (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
+
+def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ # time.sleep(2)
+
+ time0 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time1 = time.time()
+ # time.sleep(2)
+
+ time2 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time3 = time.time()
+ # time.sleep(2)
+
+ time4 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time5 = time.time()
+ # time.sleep(2)
+
+ time6 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time7 = time.time()
+
+ time8 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time9 = time.time()
+
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+# print(result1[0])
+# print(result2)
+# print(groundtruth - result1[0])
+# print(groundtruth - result2)
+
+ # print(result1[0] - result2)
+ # print(bm.allclose(groundtruth, result1[0]))
+ # print(bm.allclose(groundtruth, result2))
+ # assert bm.allclose(result1[0], result2)
+
+ time12 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time13 = time.time()
+ # time.sleep(2)
+
+ time14 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time15 = time.time()
+ # time.sleep(2)
+
+ time16 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time17 = time.time()
+ # time.sleep(2)
+
+ time18 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time19 = time.time()
+
+ time20 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time21 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ brainpy_time1 = (time13 - time12) * 1000
+ brainpy_time2 = (time15 - time14) * 1000
+ brainpy_time3 = (time17 - time16) * 1000
+ brainpy_time4 = (time19 - time18) * 1000
+ brainpy_time5 = (time21 - time20) * 1000
+
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_2: ', taichi_aot_time2, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_4: ', taichi_aot_time4, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('brainpylib_cpu_1: ', brainpy_time1, 'ms')
+ print('brainpylib_cpu_2: ', brainpy_time2, 'ms')
+ print('brainpylib_cpu_3: ', brainpy_time3, 'ms')
+ print('brainpylib_cpu_4: ', brainpy_time4, 'ms')
+ print('brainpylib_cpu_5: ', brainpy_time5, 'ms')
+ # assert(jnp.allclose(result1[0], result2))
+
+ speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
+ (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
+
+def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ # time.sleep(2)
+
+ time0 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time1 = time.time()
+ # time.sleep(2)
+
+ time2 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time3 = time.time()
+ # time.sleep(2)
+
+ time4 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time5 = time.time()
+ # time.sleep(2)
+
+ time6 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time7 = time.time()
+
+ time8 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time9 = time.time()
+
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+# print(result1[0])
+# print(result2)
+# print(groundtruth - result1[0])
+# print(groundtruth - result2)
+
+ # print(result1[0] - result2)
+ # print(bm.allclose(groundtruth, result1[0]))
+ # print(bm.allclose(groundtruth, result2))
+ # assert bm.allclose(result1[0], result2)
+
+ time12 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time13 = time.time()
+ # time.sleep(2)
+
+ time14 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time15 = time.time()
+ # time.sleep(2)
+
+ time16 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time17 = time.time()
+ # time.sleep(2)
+
+ time18 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time19 = time.time()
+
+ time20 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time21 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ brainpy_time1 = (time13 - time12) * 1000
+ brainpy_time2 = (time15 - time14) * 1000
+ brainpy_time3 = (time17 - time16) * 1000
+ brainpy_time4 = (time19 - time18) * 1000
+ brainpy_time5 = (time21 - time20) * 1000
+
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_2: ', taichi_aot_time2, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_4: ', taichi_aot_time4, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('brainpylib_cpu_1: ', brainpy_time1, 'ms')
+ print('brainpylib_cpu_2: ', brainpy_time2, 'ms')
+ print('brainpylib_cpu_3: ', brainpy_time3, 'ms')
+ print('brainpylib_cpu_4: ', brainpy_time4, 'ms')
+ print('brainpylib_cpu_5: ', brainpy_time5, 'ms')
+ # assert(jnp.allclose(result1[0], result2))
+
+ speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
+ (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
+
+def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ # time.sleep(2)
+
+ time0 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time1 = time.time()
+ # time.sleep(2)
+
+ time2 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time3 = time.time()
+ # time.sleep(2)
+
+ time4 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time5 = time.time()
+ # time.sleep(2)
+
+ time6 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time9 = time.time()
+
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+# print(result1[0])
+# print(result2)
+# print(groundtruth - result1[0])
+# print(groundtruth - result2)
+
+ # print(result1[0] - result2)
+ # print(bm.allclose(groundtruth, result1[0]))
+ # print(bm.allclose(groundtruth, result2))
+ # assert bm.allclose(result1[0], result2)
+
+ time12 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time13 = time.time()
+ # time.sleep(2)
+
+ time14 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time15 = time.time()
+ # time.sleep(2)
+
+ time16 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time17 = time.time()
+ # time.sleep(2)
+
+ time18 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time19 = time.time()
+
+ time20 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
+ time21 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ brainpy_time1 = (time13 - time12) * 1000
+ brainpy_time2 = (time15 - time14) * 1000
+ brainpy_time3 = (time17 - time16) * 1000
+ brainpy_time4 = (time19 - time18) * 1000
+ brainpy_time5 = (time21 - time20) * 1000
+
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_2: ', taichi_aot_time2, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_4: ', taichi_aot_time4, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_2: ', brainpy_time2, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_4: ', brainpy_time4, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ # assert(jnp.allclose(result1[0], result2))
+
+ speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
+ (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
+
+def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ # time.sleep(2)
+
+ time0 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time1 = time.time()
+ # time.sleep(2)
+
+ time2 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time3 = time.time()
+ # time.sleep(2)
+
+ time4 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time5 = time.time()
+ # time.sleep(2)
+
+ time6 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time7 = time.time()
+
+ time8 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time9 = time.time()
+
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+# print(result1[0])
+# print(result2)
+# print(groundtruth - result1[0])
+# print(groundtruth - result2)
+
+ # print(result1[0] - result2)
+ # print(bm.allclose(groundtruth, result1[0]))
+ # print(bm.allclose(groundtruth, result2))
+ # assert bm.allclose(result1[0], result2)
+
+ time12 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time13 = time.time()
+ # time.sleep(2)
+
+ time14 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time15 = time.time()
+ # time.sleep(2)
+
+ time16 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time17 = time.time()
+ # time.sleep(2)
+
+ time18 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time19 = time.time()
+
+ time20 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time21 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ brainpy_time1 = (time13 - time12) * 1000
+ brainpy_time2 = (time15 - time14) * 1000
+ brainpy_time3 = (time17 - time16) * 1000
+ brainpy_time4 = (time19 - time18) * 1000
+ brainpy_time5 = (time21 - time20) * 1000
+
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_2: ', taichi_aot_time2, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_4: ', taichi_aot_time4, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_2: ', brainpy_time2, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_4: ', brainpy_time4, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ # assert(jnp.allclose(result1[0], result2))
+
+ speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
+ (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
+
+def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel):
+ rng = bm.random.RandomState(seed=seed)
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
+
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ # time.sleep(2)
+
+ time0 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time1 = time.time()
+ # time.sleep(2)
+
+ time2 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time3 = time.time()
+ # time.sleep(2)
+
+ time4 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time5 = time.time()
+ # time.sleep(2)
+
+ time6 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time7 = time.time()
+
+ time8 = time.time()
+ result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time9 = time.time()
+
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+# print(result1[0])
+# print(result2)
+# print(groundtruth - result1[0])
+# print(groundtruth - result2)
+
+ # print(result1[0] - result2)
+ # print(bm.allclose(groundtruth, result1[0]))
+ # print(bm.allclose(groundtruth, result2))
+ # assert bm.allclose(result1[0], result2)
+
+ time12 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time13 = time.time()
+ # time.sleep(2)
+
+ time14 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time15 = time.time()
+ # time.sleep(2)
+
+ time16 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time17 = time.time()
+ # time.sleep(2)
+
+ time18 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time19 = time.time()
+
+ time20 = time.time()
+ result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
+ time21 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ brainpy_time1 = (time13 - time12) * 1000
+ brainpy_time2 = (time15 - time14) * 1000
+ brainpy_time3 = (time17 - time16) * 1000
+ brainpy_time4 = (time19 - time18) * 1000
+ brainpy_time5 = (time21 - time20) * 1000
+
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_2: ', taichi_aot_time2, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_4: ', taichi_aot_time4, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_2: ', brainpy_time2, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_4: ', brainpy_time4, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ # assert(jnp.allclose(result1[0], result2))
+
+ speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
+ (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
+
+
+def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel):
+ print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
+ if _type == 'homo':
+ return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel)
+ elif _type == 'uniform':
+ return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel)
+ elif _type == 'normal':
+ return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel)
+ else:
+ raise ValueError(f'Unknown type: {_type}')
+
+
+def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel):
+ print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
+ if _type == 'homo':
+ return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel)
+ elif _type == 'uniform':
+ return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel)
+ elif _type == 'normal':
+ return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel)
+ else:
+ raise ValueError(f'Unknown type: {_type}')
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+# init dataframe
+df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel',
+ 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
+ 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+ 'speedup'])
+
+### RECTANGULAR MATRIX
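+# sweep every (shape1, shape2, type, outdim_parallel, transpose) combination,
+# appending one result row per run before saving the dataframe to CSV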
+if (bm.get_platform() == 'cpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _type in types:
+ for _outdim_parallel in outdim_parallel:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel)
+ # append the results as a new row (df.shape[0] is the current number of rows)
+ df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup]
+ df.to_csv(f'{PATH}/jitconn_matvec_grad_cpu.csv', index=False)
+
+if (bm.get_platform() == 'gpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _type in types:
+ for _outdim_parallel in outdim_parallel:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel)
+ # append the results as a new row (df.shape[0] is the current number of rows)
+ df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup]
+ df.to_csv(f'{PATH}/jitconn_matvec_grad_gpu.csv', index=False)
diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py
index 016f9b0dd..b10d55d21 100644
--- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py
+++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py
@@ -1,557 +1,520 @@
# -*- coding: utf-8 -*-
import jax
import jax.numpy as jnp
from absl.testing import parameterized
-import platform
import brainpy.math as bm
-import pytest
+# full shape set: [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)]
+shapes = [(100, 200), (2, 1000), (1000, 2)]
-is_manual_test = False
-if platform.system() == 'Windows' and not is_manual_test:
- pytest.skip('Under windows, brainpy.math package may need manual tests.', allow_module_level=True)
-
-shapes = [(100, 200),
- (10, 1000),
- (2, 1000),
- (1000, 10),
- (1000, 2)]
+taichi_mv_prob_homo = bm.jitconn.event_mv_prob_homo
+taichi_mv_prob_uniform = bm.jitconn.event_mv_prob_uniform
+taichi_mv_prob_normal = bm.jitconn.event_mv_prob_normal
class Test_event_matvec_prob_conn(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs)
- bm.set_platform(platform)
- print()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.01, 0.1, 0.5],
- homo_data=[-1., ],
- bool_event=[True, False],
- seed=[1234],
- )
- def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=None, x64=False):
- print(f'_test_homo: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
-
- r1 = bm.jitconn.event_mv_prob_homo(events,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = bm.jitconn.event_mv_prob_homo(events,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
-
- r3 = bm.jitconn.event_mv_prob_homo(events,
- homo_data,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- r3 = jax.block_until_ready(r3)
- self.assertTrue(jnp.allclose(r1, r3))
-
- # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post')
- # indices = bm.as_jax(indices)
- # indptr = bm.as_jax(indptr)
- # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events,
- # shape=shape, transpose=transpose)
- # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.01, 0.1, 0.5],
- bool_event=[True, False],
- seed=[1234],
- )
- def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=None, x64=False):
- print(f'_test_homo_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
- weights = bm.as_jax(rng.random(10))
-
- f1 = jax.vmap(
- lambda event, data: bm.jitconn.event_mv_prob_homo(
- event, data, conn_prob=prob, shape=shape, seed=seed,
- transpose=transpose, outdim_parallel=outdim_parallel
- )
+ def __init__(self, *args, platform='cpu', **kwargs):
+ super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs)
+ bm.set_platform(platform)
+ print()
+
+ @parameterized.product(
+ transpose=[True, False],
+ x64=[True, False],
+ outdim_parallel=[True, False],
+ shape=shapes,
+ prob=[0.01, 0.1, 0.5],
+ homo_data=[-1., ],
+ bool_event=[True, False],
+ seed=[1234],
)
- r1 = f1(events, weights)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, weights)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=f'_test_homo_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}',
- shape=shape, transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob, seed=1234,
- x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1, 0.5]
- )
- def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_homo_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.5
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda event, data: bm.jitconn.event_mv_prob_homo(
- event, data, conn_prob=prob, shape=shape, seed=seed,
- outdim_parallel=outdim_parallel, transpose=transpose
- ).sum(),
- argnums=0
+ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False):
+ print(f'_test_homo: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'homo_data = {homo_data}, '
+ f'bool_event = {bool_event}, '
+ f'x64={x64}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ if not bool_event:
+ events = events.astype(float)
+
+ r1 = taichi_mv_prob_homo(events,
+ homo_data,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r1 = jax.block_until_ready(r1)
+
+ r2 = taichi_mv_prob_homo(events,
+ homo_data,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+ # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post')
+ # indices = bm.as_jax(indices)
+ # indptr = bm.as_jax(indptr)
+ # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events,
+ # shape=shape, transpose=transpose)
+ # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size)
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ transpose=[True, False],
+ x64=[True, False],
+ outdim_parallel=[True, False],
+ shape=shapes,
+ prob=[0.01, 0.1, 0.5],
+ bool_event=[True, False],
+ seed=[1234],
)
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
-
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
-
- r3 = f1(events, 3.)
- r3 = jax.block_until_ready(r3)
-
- self.assertTrue(jnp.allclose(r1 * 3., r3))
- self.assertTrue(jnp.allclose(r1 * 2., r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=f'test_uniform: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}',
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- w_low=w_low,
- w_high=w_high,
- bool_event=bool_event,
- seed=1234,
- x64=x64
- )
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1, 0.4]
- for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
- for bool_event in [True, False]
- )
- def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high,
- bool_event=True, seed=None, x64=False):
- print(f'_test_uniform: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- r1 = bm.jitconn.event_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = bm.jitconn.event_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
-
- r3 = bm.jitconn.event_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- r3 = jax.block_until_ready(r3)
- self.assertTrue(jnp.allclose(r1, r3))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape, transpose=transpose,
- outdim_parallel=outdim_parallel, prob=prob,
- bool_event=bool_event,
- x64=x64,
- seed=1234,
- testcase_name=f'_test_uniform_vmap: '
- f'shape={shape}, '
- f'transpose={transpose}, '
- f'bool_event={bool_event}, '
- f'outdim_parallel={outdim_parallel}, '
- f'prob={prob}, '
- f'x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for bool_event in [True, False]
- )
- def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob,
- bool_event=True, seed=None, x64=False):
- print(f'_test_uniform_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- f1 = jax.vmap(
- lambda e: bm.jitconn.event_mv_prob_uniform(e,
- w_low=0.,
- w_high=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
+ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False):
+ print(f'_test_homo_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'bool_event = {bool_event}, '
+ f'x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+ weights = bm.as_jax(rng.random(10))
+
+ f1 = jax.vmap(
+ lambda event, data: taichi_mv_prob_homo(
+ event, data, conn_prob=prob, shape=shape, seed=seed,
+ transpose=transpose, outdim_parallel=outdim_parallel
+ )[0]
+ )
+ r1 = f1(events, weights)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events, weights)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'_test_homo_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}',
+ shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob, seed=1234,
+ x64=x64)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1, 0.5]
)
-
- r1 = f1(events)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- testcase_name=f'_test_uniform_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_uniform_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda e, w_high: bm.jitconn.event_mv_prob_uniform(
- e,
- w_low=0.,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose).sum()
+ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'_test_homo_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.5
+ events = bm.as_jax(events)
+ events = events.astype(float)
+
+ f1 = jax.grad(
+ lambda event, data: taichi_mv_prob_homo(
+ event, data, conn_prob=prob, shape=shape, seed=seed,
+ outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(),
+ argnums=0
+ )
+ r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
+
+ r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
+
+ r3 = f1(events, 3.)
+ r3 = jax.block_until_ready(r3)
+
+ self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6))
+ self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'test_uniform: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_low = {w_low}, '
+ f'w_high = {w_high}, '
+ f'bool_event = {bool_event}, '
+ f'x64={x64}',
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ w_low=w_low,
+ w_high=w_high,
+ bool_event=bool_event,
+ seed=1234,
+ x64=x64
+ )
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1, 0.4]
+ for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
+ for bool_event in [True, False]
)
-
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(bm.allclose(r1 * 2., r2))
- # print(r1)
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- w_mu=w_mu,
- w_sigma=w_sigma,
- bool_event=bool_event,
- x64=x64,
- seed=1234,
- testcase_name=f'_test_normal: '
- f'shape={shape}, '
- f'transpose={transpose}, '
- f'outdim_parallel={outdim_parallel}, '
- f'prob={prob}, '
- f'w_mu={w_mu}, '
- f'w_sigma={w_sigma}, '
- f'bool_event={bool_event}, '
- f'x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1, ]
- for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
- for bool_event in [True, False]
- )
- def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma,
- bool_event=True, seed=None, x64=False):
- print(f'_test_normal: shape = {shape}, '
- f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, '
- f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- r1 = bm.jitconn.event_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = bm.jitconn.event_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
-
- r3 = bm.jitconn.event_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- r3 = jax.block_until_ready(r3)
- self.assertTrue(jnp.allclose(r1, r3))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- bool_event=bool_event,
- x64=x64,
- seed=1234,
- testcase_name=f'_test_normal_vmap: '
- f'shape={shape}, '
- f'transpose={transpose}, '
- f'outdim_parallel={outdim_parallel}, '
- f'prob={prob}, '
- f'bool_event={bool_event}, '
- f'x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for bool_event in [True, False]
- )
- def test_normal_vmap(self, shape, transpose, outdim_parallel, prob,
- bool_event=True, seed=None, x64=False):
- print(f'_test_normal_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e,
- w_mu=0.,
- w_sigma=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose))
- r1 = f1(events)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- x64=x64,
- seed=1234,
- testcase_name=f'_test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.jit(
- jax.grad(
- lambda e, w_sigma: bm.jitconn.event_mv_prob_normal(
- e,
- w_mu=0.,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose).sum()
- )
+ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high,
+ bool_event=True, seed=1234, x64=False):
+ print(f'_test_uniform: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_low = {w_low}, '
+ f'w_high = {w_high}, '
+ f'x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+
+ r1 = taichi_mv_prob_uniform(events,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r1 = jax.block_until_ready(r1)
+
+ r2 = taichi_mv_prob_uniform(events,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel, prob=prob,
+ bool_event=bool_event,
+ x64=x64,
+ seed=1234,
+ testcase_name=f'_test_uniform_vmap: '
+ f'shape={shape}, '
+ f'transpose={transpose}, '
+ f'bool_event={bool_event}, '
+ f'outdim_parallel={outdim_parallel}, '
+ f'prob={prob}, '
+ f'x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for bool_event in [True, False]
+ )
+ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob,
+ bool_event=True, seed=1234, x64=False):
+ print(f'_test_uniform_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+
+ f1 = jax.vmap(
+ lambda e: taichi_mv_prob_uniform(e,
+ w_low=0.,
+ w_high=1.,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ )
+
+ r1 = f1(events)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ testcase_name=f'_test_uniform_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'_test_uniform_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ events = events.astype(float)
+
+ f1 = jax.grad(
+ lambda e, w_high: taichi_mv_prob_uniform(
+ e,
+ w_low=0.,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose).sum()
+ )
+
+ r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6))
+ # print(r1)
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ bool_event=bool_event,
+ x64=x64,
+ seed=1234,
+ testcase_name=f'_test_normal: '
+ f'shape={shape}, '
+ f'transpose={transpose}, '
+ f'outdim_parallel={outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_mu={w_mu}, '
+ f'w_sigma={w_sigma}, '
+ f'bool_event={bool_event}, '
+ f'x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1, ]
+ for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
+ for bool_event in [True, False]
+ )
+ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma,
+ bool_event=True, seed=1234, x64=False):
+ print(f'_test_normal: shape = {shape}, '
+ f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, '
+ f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+
+ r1 = taichi_mv_prob_normal(events,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r1 = jax.block_until_ready(r1)
+
+ r2 = taichi_mv_prob_normal(events,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ bool_event=bool_event,
+ x64=x64,
+ seed=1234,
+ testcase_name=f'_test_normal_vmap: '
+ f'shape={shape}, '
+ f'transpose={transpose}, '
+ f'outdim_parallel={outdim_parallel}, '
+ f'prob={prob}, '
+ f'bool_event={bool_event}, '
+ f'x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for bool_event in [True, False]
+ )
+ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob,
+ bool_event=True, seed=1234, x64=False):
+ print(f'_test_normal_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+
+ f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e,
+ w_mu=0.,
+ w_sigma=1.,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose))
+ r1 = f1(events)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ x64=x64,
+ seed=1234,
+ testcase_name=f'_test_normal_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
)
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(bm.allclose(r1 * 2, r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
+ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'_test_normal_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ events = events.astype(float)
+
+ f1 = jax.jit(
+ jax.grad(
+ lambda e, w_sigma: taichi_mv_prob_normal(
+ e,
+ w_mu=0.,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose).sum()
+ )
+ )
+ r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py
deleted file mode 100644
index 778212547..000000000
--- a/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import jax
-import pytest
-
-import test_event_matvec
-
-if jax.default_backend() != 'gpu':
- pytest.skip("No gpu available.", allow_module_level=True)
-
-
-class Test_event_matvec_prob_conn_GPU(test_event_matvec.Test_event_matvec_prob_conn):
- def __init__(self, *args, **kwargs):
- super(Test_event_matvec_prob_conn_GPU, self).__init__(*args, **kwargs, platform='gpu')
diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py
new file mode 100644
index 000000000..b2fa77229
--- /dev/null
+++ b/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py
@@ -0,0 +1,564 @@
+# -*- coding: utf-8 -*-
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+from absl.testing import parameterized
+
+import platform
+import brainpy.math as bm
+
+import pytest
+pytest.skip('Old implementation.', allow_module_level=True)
+is_manual_test = False
+if platform.system() == 'Windows' and not is_manual_test:
+ pytest.skip('Under windows, brainpy.math package may need manual tests.', allow_module_level=True)
+
+shapes = [(100, 200),
+ # (10, 1000),
+ (2, 1000),
+ # (1000, 10),
+ (1000, 2)]
+
+brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib')
+taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi')
+brainpylib_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='brainpylib')
+taichi_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='taichi')
+brainpylib_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='brainpylib')
+taichi_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='taichi')
+
+class Test_event_matvec_prob_conn(parameterized.TestCase):
+ def __init__(self, *args, platform='cpu', **kwargs):
+ super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs)
+ bm.set_platform(platform)
+ print()
+
+ @parameterized.product(
+ transpose=[True, False],
+ x64=[True, False],
+ outdim_parallel=[True, False],
+ shape=shapes,
+ prob=[0.01, 0.1, 0.5],
+ homo_data=[-1., ],
+ bool_event=[True, False],
+ seed=[1234],
+ )
+ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=None, x64=False):
+ print(f'_test_homo: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'homo_data = {homo_data}, '
+ f'bool_event = {bool_event}, '
+ f'x64={x64}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ if not bool_event:
+ events = events.astype(float)
+
+ r1 = brainpylib_mv_prob_homo(events,
+ homo_data,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r1 = jax.block_until_ready(r1)
+
+ r2 = brainpylib_mv_prob_homo(events,
+ homo_data,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2))
+
+ r3 = brainpylib_mv_prob_homo(events,
+ homo_data,
+ conn_prob=prob,
+ shape=(shape[1], shape[0]),
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=not transpose)
+ r3 = jax.block_until_ready(r3)
+ self.assertTrue(jnp.allclose(r1, r3))
+
+ # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post')
+ # indices = bm.as_jax(indices)
+ # indptr = bm.as_jax(indptr)
+ # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events,
+ # shape=shape, transpose=transpose)
+ # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size)
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ transpose=[True, False],
+ x64=[True, False],
+ outdim_parallel=[True, False],
+ shape=shapes,
+ prob=[0.01, 0.1, 0.5],
+ bool_event=[True, False],
+ seed=[1234],
+ )
+ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=None, x64=False):
+ print(f'_test_homo_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'bool_event = {bool_event}, '
+ f'x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+ weights = bm.as_jax(rng.random(10))
+
+ f1 = jax.vmap(
+ lambda event, data: brainpylib_mv_prob_homo(
+ event, data, conn_prob=prob, shape=shape, seed=seed,
+ transpose=transpose, outdim_parallel=outdim_parallel
+ )
+ )
+ r1 = f1(events, weights)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events, weights)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'_test_homo_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}',
+ shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob, seed=1234,
+ x64=x64)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1, 0.5]
+ )
+ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'_test_homo_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.5
+ events = bm.as_jax(events)
+ events = events.astype(float)
+
+ f1 = jax.grad(
+ lambda event, data: brainpylib_mv_prob_homo(
+ event, data, conn_prob=prob, shape=shape, seed=seed,
+ outdim_parallel=outdim_parallel, transpose=transpose
+ ).sum(),
+ argnums=0
+ )
+ r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
+
+ r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
+
+ r3 = f1(events, 3.)
+ r3 = jax.block_until_ready(r3)
+
+ self.assertTrue(jnp.allclose(r1 * 3., r3))
+ self.assertTrue(jnp.allclose(r1 * 2., r2))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'test_uniform: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_low = {w_low}, '
+ f'w_high = {w_high}, '
+ f'bool_event = {bool_event}, '
+ f'x64={x64}',
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ w_low=w_low,
+ w_high=w_high,
+ bool_event=bool_event,
+ seed=1234,
+ x64=x64
+ )
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1, 0.4]
+ for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
+ for bool_event in [True, False]
+ )
+ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high,
+ bool_event=True, seed=None, x64=False):
+ print(f'_test_uniform: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_low = {w_low}, '
+ f'w_high = {w_high}, '
+ f'x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+
+ r1 = brainpylib_mv_prob_uniform(events,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r1 = jax.block_until_ready(r1)
+
+ r2 = brainpylib_mv_prob_uniform(events,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2))
+
+ r3 = brainpylib_mv_prob_uniform(events,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=(shape[1], shape[0]),
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=not transpose)
+ r3 = jax.block_until_ready(r3)
+ self.assertTrue(jnp.allclose(r1, r3))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel, prob=prob,
+ bool_event=bool_event,
+ x64=x64,
+ seed=1234,
+ testcase_name=f'_test_uniform_vmap: '
+ f'shape={shape}, '
+ f'transpose={transpose}, '
+ f'bool_event={bool_event}, '
+ f'outdim_parallel={outdim_parallel}, '
+ f'prob={prob}, '
+ f'x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for bool_event in [True, False]
+ )
+ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob,
+ bool_event=True, seed=None, x64=False):
+ print(f'_test_uniform_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+
+ f1 = jax.vmap(
+ lambda e: brainpylib_mv_prob_uniform(e,
+ w_low=0.,
+ w_high=1.,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ )
+
+ r1 = f1(events)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ testcase_name=f'_test_uniform_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'_test_uniform_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ events = events.astype(float)
+
+ f1 = jax.grad(
+ lambda e, w_high: brainpylib_mv_prob_uniform(
+ e,
+ w_low=0.,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose).sum()
+ )
+
+ r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(bm.allclose(r1 * 2., r2))
+ # print(r1)
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ bool_event=bool_event,
+ x64=x64,
+ seed=1234,
+ testcase_name=f'_test_normal: '
+ f'shape={shape}, '
+ f'transpose={transpose}, '
+ f'outdim_parallel={outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_mu={w_mu}, '
+ f'w_sigma={w_sigma}, '
+ f'bool_event={bool_event}, '
+ f'x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1, ]
+ for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
+ for bool_event in [True, False]
+ )
+ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma,
+ bool_event=True, seed=None, x64=False):
+ print(f'_test_normal: shape = {shape}, '
+ f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, '
+ f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+
+ r1 = brainpylib_mv_prob_normal(events,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r1 = jax.block_until_ready(r1)
+
+ r2 = brainpylib_mv_prob_normal(events,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2))
+
+ r3 = brainpylib_mv_prob_normal(events,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=(shape[1], shape[0]),
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=not transpose)
+ r3 = jax.block_until_ready(r3)
+ self.assertTrue(jnp.allclose(r1, r3))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ bool_event=bool_event,
+ x64=x64,
+ seed=1234,
+ testcase_name=f'_test_normal_vmap: '
+ f'shape={shape}, '
+ f'transpose={transpose}, '
+ f'outdim_parallel={outdim_parallel}, '
+ f'prob={prob}, '
+ f'bool_event={bool_event}, '
+ f'x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for bool_event in [True, False]
+ )
+ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob,
+ bool_event=True, seed=None, x64=False):
+ print(f'_test_normal_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
+ events = bm.as_jax(events)
+ if not bool_event:
+ events = events.astype(float)
+
+ f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e,
+ w_mu=0.,
+ w_sigma=1.,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose))
+ r1 = f1(events)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(jnp.allclose(r1, r2))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ x64=x64,
+ seed=1234,
+ testcase_name=f'_test_normal_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'_test_normal_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}')
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ events = bm.as_jax(events)
+ events = events.astype(float)
+
+ f1 = jax.jit(
+ jax.grad(
+ lambda e, w_sigma: brainpylib_mv_prob_normal(
+ e,
+ w_mu=0.,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose).sum()
+ )
+ )
+ r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
+ r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
+ self.assertTrue(bm.allclose(r1 * 2, r2))
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py
index 91c48fc66..2e6e406cf 100644
--- a/brainpy/_src/math/jitconn/tests/test_matvec.py
+++ b/brainpy/_src/math/jitconn/tests/test_matvec.py
@@ -1,65 +1,61 @@
# -*- coding: utf-8 -*-
+from functools import partial
import jax
import jax.numpy as jnp
from absl.testing import parameterized
import brainpy.math as bm
-import platform
-import pytest
-is_manual_test = False
-if platform.system() == 'Windows' and not is_manual_test:
- pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
+# full shape set: [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)]
+shapes = [(100, 200), (2, 1000), (1000, 2)]
-shapes = [(100, 200),
- (10, 1000),
- (2, 1000),
- (1000, 10),
- (1000, 2)]
+taichi_mv_prob_homo = bm.jitconn.mv_prob_homo
+taichi_mv_prob_uniform = bm.jitconn.mv_prob_uniform
+taichi_mv_prob_normal = bm.jitconn.mv_prob_normal
class Test_matvec_prob_conn(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_matvec_prob_conn, self).__init__(*args, **kwargs)
- bm.set_platform(platform)
- print()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_homo, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}, '
- f'x64 = {x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- homo_data=homo_data,
- seed=1234)
- for x64 in [True, False]
- for transpose in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for homo_data in [-1., 1.]
- )
- def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=None, x64=False):
- print(f'test_homo: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = bm.jitconn.mv_prob_homo(vector,
+ def __init__(self, *args, platform='cpu', **kwargs):
+ super(Test_matvec_prob_conn, self).__init__(*args, **kwargs)
+ bm.set_platform(platform)
+ print()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_homo, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'homo_data = {homo_data}, '
+ f'x64 = {x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ homo_data=homo_data,
+ seed=1234)
+ for x64 in [True, False]
+ for transpose in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for homo_data in [-1., 1.]
+ )
+ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False):
+ print(f'test_homo: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'homo_data = {homo_data}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ r1 = taichi_mv_prob_homo(vector,
homo_data,
conn_prob=prob,
shape=shape,
@@ -67,163 +63,152 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non
outdim_parallel=outdim_parallel,
transpose=transpose)
- r2 = bm.jitconn.mv_prob_homo(vector,
+ r2 = taichi_mv_prob_homo(vector,
homo_data,
conn_prob=prob,
shape=shape,
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
- self.assertTrue(jnp.allclose(r1, r2))
-
- r2 = bm.jitconn.mv_prob_homo(vector,
- homo_data,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- self.assertTrue(jnp.allclose(r1, r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_homo_vmap, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'test_homo_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
- weights = bm.as_jax(rng.random(10))
-
- f1 = jax.vmap(
- lambda event, data: bm.jitconn.mv_prob_homo(
- event, data,
- conn_prob=prob, shape=shape, seed=seed,
- outdim_parallel=outdim_parallel, transpose=transpose
- )
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_homo_vmap, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
)
- r1 = f1(events, weights)
- r2 = f1(events, weights)
- self.assertTrue(jnp.allclose(r1, r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_homo_grad, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_homo_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda event, data: bm.jitconn.mv_prob_homo(
- event, data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- ).sum(),
- argnums=0
+ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'test_homo_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
+ weights = bm.as_jax(rng.random(10))
+
+ f1 = jax.vmap(
+ lambda event, data: taichi_mv_prob_homo(
+ event, data,
+ conn_prob=prob, shape=shape, seed=seed,
+ outdim_parallel=outdim_parallel, transpose=transpose
+ )[0]
+ )
+ r1 = f1(events, weights)
+ r2 = f1(events, weights)
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_homo_grad, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
)
- r1 = f1(events, 1.)
- r2 = f1(events, 2.)
-
- self.assertTrue(jnp.allclose(r1 * 2., r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_uniform, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}'
- f'x64 = {x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- w_low=w_low,
- w_high=w_high,
- x64=x64,
- seed=1234)
- for x64 in [True, False]
- for transpose in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
- )
- def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None, x64=False):
- print(f'test_uniform: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'x64 = {x64}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = bm.jitconn.mv_prob_uniform(events,
+ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'_test_homo_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5
+ events = events.astype(float)
+
+ f1 = jax.grad(
+ lambda event, data: taichi_mv_prob_homo(
+ event, data,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose
+ )[0].sum(),
+ argnums=0
+ )
+ r1 = f1(events, 1.)
+ r2 = f1(events, 2.)
+
+ self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_uniform, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_low = {w_low}, '
+ f'w_high = {w_high}'
+ f'x64 = {x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ w_low=w_low,
+ w_high=w_high,
+ x64=x64,
+ seed=1234)
+ for x64 in [True, False]
+ for transpose in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
+ )
+ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False):
+ print(f'test_uniform: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_low = {w_low}, '
+ f'w_high = {w_high}, '
+ f'x64 = {x64}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ r1 = taichi_mv_prob_uniform(events,
w_low=w_low,
w_high=w_high,
conn_prob=prob,
@@ -232,7 +217,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s
outdim_parallel=outdim_parallel,
transpose=transpose)
- r2 = bm.jitconn.mv_prob_uniform(events,
+ r2 = taichi_mv_prob_uniform(events,
w_low=w_low,
w_high=w_high,
conn_prob=prob,
@@ -240,58 +225,45 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- r2 = bm.jitconn.mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=f'test_uniform_vmap, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}',
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'test_uniform_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
-
- f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e,
+ c = jnp.allclose(r1, r2, atol=1e-6)
+ if not c:
+ print(r1, r2)
+ self.assertTrue(c)
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'test_uniform_vmap, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}',
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'test_uniform_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
+
+ f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e,
w_low=0.,
w_high=1.,
conn_prob=prob,
@@ -300,107 +272,107 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None,
outdim_parallel=outdim_parallel,
transpose=transpose))
- r1 = f1(events)
- r2 = f1(events)
- self.assertTrue(jnp.allclose(r1, r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_uniform_grad, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'x64={x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64)
- for x64 in [True, False]
- for transpose in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_uniform_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- f1 = jax.grad(
- lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform(
- e,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- ).sum()
+ r1 = f1(events)
+ r2 = f1(events)
+ self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_uniform_grad, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'x64={x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64)
+ for x64 in [True, False]
+ for transpose in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
)
-
- r1 = f1(events, 0., 1.)
- r2 = f1(events, 0., 2.)
-
- self.assertTrue(bm.allclose(r1 * 2., r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(
- testcase_name=(f'test_normal, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_mu = {w_mu}, '
- f'w_sigma = {w_sigma},'
- f'x64={x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- w_mu=w_mu,
- w_sigma=w_sigma,
- seed=1234
+ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'_test_uniform_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ f1 = jax.grad(
+ lambda e, w_low, w_high: taichi_mv_prob_uniform(
+ e,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose
+ )[0].sum()
+ )
+
+ r1 = f1(events, 0., 1.)
+ r2 = f1(events, 0., 2.)
+
+ self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name=(f'test_normal, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_mu = {w_mu}, '
+ f'w_sigma = {w_sigma},'
+ f'x64={x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ seed=1234
+ )
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
- )
- def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=None, x64=False):
- print(f'_test_normal: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_mu = {w_mu}, '
- f'w_sigma = {w_sigma}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = bm.jitconn.mv_prob_normal(events,
+ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False):
+ print(f'_test_normal: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_mu = {w_mu}, '
+ f'w_sigma = {w_sigma}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ r1 = taichi_mv_prob_normal(events,
w_mu=w_mu,
w_sigma=w_sigma,
conn_prob=prob,
@@ -409,7 +381,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se
outdim_parallel=outdim_parallel,
transpose=transpose)
- r2 = bm.jitconn.mv_prob_normal(events,
+ r2 = taichi_mv_prob_normal(events,
w_mu=w_mu,
w_sigma=w_sigma,
conn_prob=prob,
@@ -417,59 +389,46 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
+ c = jnp.allclose(r1, r2, atol=1e-6)
+ if not c:
+ print(r1, r2)
+ self.assertTrue(c)
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'test_normal_vmap, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'x64={x64}',
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'_test_normal_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
- r2 = bm.jitconn.mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=f'test_normal_vmap, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'x64={x64}',
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_normal_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
-
- f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e,
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
+
+ f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e,
w_mu=0.,
w_sigma=1.,
conn_prob=prob,
@@ -477,65 +436,66 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose))
- r1 = f1(events)
- r2 = f1(events)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64,
- testcase_name=f'test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda e, w_sigma: bm.jitconn.mv_prob_normal(
- e,
- w_mu=0.,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- ).sum()
+ r1 = f1(events)
+ r2 = f1(events)
+ c = jnp.allclose(r1, r2, atol=1e-6)
+ if not c:
+ print(r1, r2)
+ print(r1 - r2)
+ self.assertTrue(c)
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64,
+ testcase_name=f'test_normal_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
)
- r1 = f1(events, 1.)
- r2 = f1(events, 2.)
- self.assertTrue(bm.allclose(r1 * 2., r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
+ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
+ print(f'_test_normal_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ events = events.astype(float)
+
+ f1 = jax.grad(
+ lambda e, w_sigma: taichi_mv_prob_normal(
+ e,
+ w_mu=0.,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose
+ )[0].sum()
+ )
+ r1 = f1(events, 1.)
+ r2 = f1(events, 2.)
+ self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
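
For reference, a minimal sketch of the call pattern these rewritten tests exercise (argument values are illustrative; note that under `vmap`/`grad` the tests index the result with `[0]`, which suggests the taichi-backed operator returns a tuple-like container):

import brainpy.math as bm

rng = bm.random.RandomState()
vec = bm.as_jax(rng.random(200))   # input length matches shape[1] when transpose=False

out = bm.jitconn.mv_prob_homo(vec, 1.0,
                              conn_prob=0.1,
                              shape=(100, 200),
                              seed=1234,
                              outdim_parallel=True,
                              transpose=False)
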
diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py b/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py
deleted file mode 100644
index f227c0e6a..000000000
--- a/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import jax
-import pytest
-
-import test_matvec
-
-if jax.default_backend() != 'gpu':
- pytest.skip("No gpu available.", allow_module_level=True)
-
-
-class Test_matvec_prob_conn_GPU(test_matvec.Test_matvec_prob_conn):
- def __init__(self, *args, **kwargs):
- super(Test_matvec_prob_conn_GPU, self).__init__(*args, **kwargs, platform='gpu')
diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_old.py b/brainpy/_src/math/jitconn/tests/test_matvec_old.py
new file mode 100644
index 000000000..360711e7b
--- /dev/null
+++ b/brainpy/_src/math/jitconn/tests/test_matvec_old.py
@@ -0,0 +1,551 @@
+# -*- coding: utf-8 -*-
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+from absl.testing import parameterized
+
+import brainpy.math as bm
+import platform
+import pytest
+
+pytest.skip('Old implementation.', allow_module_level=True)
+is_manual_test = False
+if platform.system() == 'Windows' and not is_manual_test:
+ pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
+
+shapes = [(100, 200),
+ (10, 1000),
+ (2, 1000),
+ (1000, 10),
+ (1000, 2)]
+
+brainpylib_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='brainpylib')
+taichi_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='taichi')
+brainpylib_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='brainpylib')
+taichi_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='taichi')
+brainpylib_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='brainpylib')
+taichi_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='taichi')
+
+class Test_matvec_prob_conn(parameterized.TestCase):
+ def __init__(self, *args, platform='cpu', **kwargs):
+ super(Test_matvec_prob_conn, self).__init__(*args, **kwargs)
+ bm.set_platform(platform)
+ print()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_homo, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'homo_data = {homo_data}, '
+ f'x64 = {x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ homo_data=homo_data,
+ seed=1234)
+ for x64 in [True, False]
+ for transpose in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for homo_data in [-1., 1.]
+ )
+ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=None, x64=False):
+ print(f'test_homo: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'homo_data = {homo_data}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ r1 = brainpylib_mv_prob_homo(vector,
+ homo_data,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+
+ r2 = brainpylib_mv_prob_homo(vector,
+ homo_data,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ self.assertTrue(jnp.allclose(r1, r2))
+
+ r2 = brainpylib_mv_prob_homo(vector,
+ homo_data,
+ conn_prob=prob,
+ shape=(shape[1], shape[0]),
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=not transpose)
+ self.assertTrue(jnp.allclose(r1, r2))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_homo_vmap, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'test_homo_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
+ weights = bm.as_jax(rng.random(10))
+
+ f1 = jax.vmap(
+ lambda event, data: brainpylib_mv_prob_homo(
+ event, data,
+ conn_prob=prob, shape=shape, seed=seed,
+ outdim_parallel=outdim_parallel, transpose=transpose
+ )
+ )
+ r1 = f1(events, weights)
+ r2 = f1(events, weights)
+ self.assertTrue(jnp.allclose(r1, r2))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_homo_grad, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'_test_homo_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5
+ events = events.astype(float)
+
+ f1 = jax.grad(
+ lambda event, data: brainpylib_mv_prob_homo(
+ event, data,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose
+ ).sum(),
+ argnums=0
+ )
+ r1 = f1(events, 1.)
+ r2 = f1(events, 2.)
+
+ self.assertTrue(jnp.allclose(r1 * 2., r2))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_uniform, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_low = {w_low}, '
+ f'w_high = {w_high}'
+ f'x64 = {x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ w_low=w_low,
+ w_high=w_high,
+ x64=x64,
+ seed=1234)
+ for x64 in [True, False]
+ for transpose in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
+ )
+ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None, x64=False):
+ print(f'test_uniform: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_low = {w_low}, '
+ f'w_high = {w_high}, '
+ f'x64 = {x64}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ r1 = brainpylib_mv_prob_uniform(events,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+
+ r2 = brainpylib_mv_prob_uniform(events,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ c = jnp.allclose(r1, r2)
+ if not c:
+ print(r1, r2)
+ self.assertTrue(c)
+
+ r2 = brainpylib_mv_prob_uniform(events,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=(shape[1], shape[0]),
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=not transpose)
+ c = jnp.allclose(r1, r2)
+ if not c:
+ print(r1, r2)
+ self.assertTrue(c)
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'test_uniform_vmap, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, x64={x64}',
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'test_uniform_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
+
+ f1 = jax.vmap(lambda e: brainpylib_mv_prob_uniform(e,
+ w_low=0.,
+ w_high=1.,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose))
+
+ r1 = f1(events)
+ r2 = f1(events)
+ self.assertTrue(jnp.allclose(r1, r2))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=(f'test_uniform_grad, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'x64={x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64)
+ for x64 in [True, False]
+ for transpose in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'_test_uniform_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ f1 = jax.grad(
+ lambda e, w_low, w_high: brainpylib_mv_prob_uniform(
+ e,
+ w_low=w_low,
+ w_high=w_high,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose
+ ).sum()
+ )
+
+ r1 = f1(events, 0., 1.)
+ r2 = f1(events, 0., 2.)
+
+ self.assertTrue(bm.allclose(r1 * 2., r2))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name=(f'test_normal, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_mu = {w_mu}, '
+ f'w_sigma = {w_sigma},'
+ f'x64={x64}'),
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ seed=1234
+ )
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
+ )
+ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=None, x64=False):
+ print(f'_test_normal: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'w_mu = {w_mu}, '
+ f'w_sigma = {w_sigma}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+ r1 = brainpylib_mv_prob_normal(events,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+
+ r2 = brainpylib_mv_prob_normal(events,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose)
+ c = jnp.allclose(r1, r2)
+ if not c:
+ print(r1, r2)
+ self.assertTrue(c)
+
+ r2 = brainpylib_mv_prob_normal(events,
+ w_mu=w_mu,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=(shape[1], shape[0]),
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=not transpose)
+ c = jnp.allclose(r1, r2)
+ if not c:
+ print(r1, r2)
+ self.assertTrue(c)
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'test_normal_vmap, shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'x64={x64}',
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234)
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'_test_normal_vmap: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
+
+ f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e,
+ w_mu=0.,
+ w_sigma=1.,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose))
+ r1 = f1(events)
+ r2 = f1(events)
+ c = jnp.allclose(r1, r2)
+ if not c:
+ print(r1, r2)
+ self.assertTrue(c)
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel,
+ prob=prob,
+ seed=1234,
+ x64=x64,
+ testcase_name=f'test_normal_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}, '
+ f'x64={x64}')
+ for transpose in [True, False]
+ for x64 in [True, False]
+ for outdim_parallel in [True, False]
+ for shape in shapes
+ for prob in [0.01, 0.1]
+ )
+ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
+ print(f'_test_normal_grad: '
+ f'shape = {shape}, '
+ f'transpose = {transpose}, '
+ f'outdim_parallel = {outdim_parallel}, '
+ f'prob={prob}')
+
+ if x64:
+ bm.enable_x64()
+ rng = bm.random.RandomState()
+ events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
+ events = events.astype(float)
+
+ f1 = jax.grad(
+ lambda e, w_sigma: brainpylib_mv_prob_normal(
+ e,
+ w_mu=0.,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose
+ ).sum()
+ )
+ r1 = f1(events, 1.)
+ r2 = f1(events, 2.)
+ print('r1:', r1)
+ print('r2:', r2)
+ self.assertTrue(bm.allclose(r1 * 2., r2))
+
+ if x64:
+ bm.disable_x64()
+ bm.clear_buffer_memory()
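
The `functools.partial` bindings at the top of this resurrected old-test file pin the backend once (`method='brainpylib'` vs. `method='taichi'`) so the test bodies stay backend-agnostic. The same idea in isolation, with a toy stand-in for the real operator:

from functools import partial

def mv_prob(x, *, method='taichi'):
    # stand-in for bm.jitconn.mv_prob_homo(..., method=...)
    return f'{method}:{x}'

brainpylib_mv = partial(mv_prob, method='brainpylib')
taichi_mv = partial(mv_prob, method='taichi')
print(brainpylib_mv(1), taichi_mv(1))   # -> brainpylib:1 taichi:1
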
diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py
index b5d12d9ce..cf2b2343d 100644
--- a/brainpy/_src/math/ndarray.py
+++ b/brainpy/_src/math/ndarray.py
@@ -9,9 +9,11 @@
from jax.dtypes import canonicalize_dtype
from jax.tree_util import register_pytree_node_class
-import brainpy.math
from brainpy.errors import MathError
+bm = None
+
+
__all__ = [
'Array', 'ndarray', 'JaxArray', # alias of Array
'ShardedArray',
@@ -79,7 +81,7 @@ class Array(object):
"""
- __slots__ = ('_value', '_keep_sharding')
+ __slots__ = ('_value', )
def __init__(self, value, dtype: Any = None):
# array value
@@ -132,7 +134,7 @@ def value(self, value):
if value.dtype != self_value.dtype:
raise MathError(f"The dtype of the original data is {self_value.dtype}, "
f"while we got {value.dtype}.")
- self._value = value.value if isinstance(value, Array) else value
+ self._value = value
def update(self, value):
"""Update the value of this Array.
@@ -1039,7 +1041,9 @@ def __jax_array__(self):
def as_variable(self):
"""As an instance of Variable."""
- return brainpy.math.Variable(self)
+ global bm
+ if bm is None: from brainpy import math as bm
+ return bm.Variable(self)
def __format__(self, specification):
return self.value.__format__(specification)
@@ -1473,7 +1477,9 @@ def fill_(self, value):
return self
def uniform_(self, low=0., high=1.):
- self.value = brainpy.math.random.uniform(low, high, self.shape)
+ global bm
+ if bm is None: from brainpy import math as bm
+ self.value = bm.random.uniform(low, high, self.shape)
return self
def log_normal_(self, mean=1, std=2):
@@ -1489,14 +1495,18 @@ def log_normal_(self, mean=1, std=2):
mean: the mean value.
std: the standard deviation.
"""
- self.value = brainpy.math.random.lognormal(mean, std, self.shape)
+ global bm
+ if bm is None: from brainpy import math as bm
+ self.value = bm.random.lognormal(mean, std, self.shape)
return self
def normal_(self, ):
"""
Fills self tensor with elements samples from the normal distribution parameterized by mean and std.
"""
- self.value = brainpy.math.random.randn(*self.shape)
+ global bm
+ if bm is None: from brainpy import math as bm
+ self.value = bm.random.randn(*self.shape)
return self
def cuda(self):
@@ -1549,11 +1559,12 @@ def value(self):
Returns:
The stored data.
"""
+ v = self._value
# keep sharding constraints
- if self._keep_sharding and hasattr(self._value, 'sharding') and (self._value.sharding is not None):
- return jax.lax.with_sharding_constraint(self._value, self._value.sharding)
+ if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None):
+ return jax.lax.with_sharding_constraint(v, v.sharding)
# return the value
- return self._value
+ return v
@value.setter
def value(self, value):
@@ -1574,6 +1585,6 @@ def value(self, value):
if value.dtype != self_value.dtype:
raise MathError(f"The dtype of the original data is {self_value.dtype}, "
f"while we got {value.dtype}.")
- self._value = value.value if isinstance(value, Array) else value
+ self._value = value
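
The `bm = None` module global plus the `global bm` re-import above is a deferred (lazy) import that breaks the circular dependency between `ndarray.py` and `brainpy.math`. The pattern in isolation, with a placeholder module standing in for `brainpy.math`:

_mod = None   # filled in on first use

def lazy_dumps(obj):
    global _mod
    if _mod is None:
        import json as _mod   # stand-in for `from brainpy import math as bm`
    return _mod.dumps(obj)

print(lazy_dumps({'ok': True}))   # the import happens only on the first call
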
diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py
index 6122f6cd8..ad8a5ccf6 100644
--- a/brainpy/_src/math/object_transform/autograd.py
+++ b/brainpy/_src/math/object_transform/autograd.py
@@ -6,6 +6,7 @@
import jax
import numpy as np
+
if jax.__version__ >= '0.4.16':
from jax.extend import linear_util
else:
@@ -15,35 +16,25 @@
from jax._src.api import (_vjp, _jvp)
from jax.api_util import argnums_partial
from jax.interpreters import xla
-from jax.tree_util import (
- tree_flatten, tree_unflatten,
- tree_map, tree_transpose,
- tree_structure
-)
+from jax.tree_util import (tree_flatten, tree_unflatten,
+ tree_map, tree_transpose,
+ tree_structure)
from jax.util import safe_map
from brainpy import tools, check
from brainpy._src.math.ndarray import Array, _as_jax_array_
-from .tools import (
- dynvar_deprecation,
- node_deprecation,
- get_stack_cache,
- cache_stack,
-)
-from .base import (
- BrainPyObject,
- ObjectTransform
-)
-from .variables import (
- Variable,
- VariableStack,
- current_transform_number,
- new_transform,
-)
+from .tools import (dynvar_deprecation,
+ node_deprecation,
+ get_stack_cache,
+ cache_stack)
+from .base import (BrainPyObject, ObjectTransform)
+from .variables import (Variable, VariableStack)
+from .tools import eval_shape
__all__ = [
'grad', # gradient of scalar function
'vector_grad', # gradient of vector/matrix/...
+ 'functional_vector_grad',
'jacobian', 'jacrev', 'jacfwd', # gradient of jacobian
'hessian', # gradient of hessian
]
@@ -210,36 +201,21 @@ def __call__(self, *args, **kwargs):
elif not self._eval_dyn_vars: # evaluate dynamical variables
stack = get_stack_cache(self.target)
if stack is None:
- with new_transform(self):
- with VariableStack() as stack:
- if current_transform_number() > 1:
- rets = self._transform(
- [v.value for v in self._grad_vars], # variables for gradients
- {}, # dynamical variables
- *args,
- **kwargs
- )
- else:
- rets = jax.eval_shape(
- self._transform,
- [v.value for v in self._grad_vars], # variables for gradients
- {}, # dynamical variables
- *args,
- **kwargs
- )
+ with VariableStack() as stack:
+ rets = eval_shape(self._transform,
+ [v.value for v in self._grad_vars], # variables for gradients
+ {}, # dynamical variables
+ *args,
+ **kwargs)
cache_stack(self.target, stack)
- self._dyn_vars = stack
- self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
- self._eval_dyn_vars = True
+ self._dyn_vars = stack
+ self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
+ self._eval_dyn_vars = True
- # if not the outermost transformation
- if current_transform_number():
- return self._return(rets)
- else:
- self._dyn_vars = stack
- self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
- self._eval_dyn_vars = True
+ # if not the outermost transformation
+ if not stack.is_first_stack():
+ return self._return(rets)
rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
@@ -466,7 +442,8 @@ def _std_basis(pytree):
return _unravel_array_into_pytree(pytree, 1, flat_basis)
-_isleaf = lambda x: isinstance(x, Array)
+def _isleaf(x):
+ return isinstance(x, Array)
def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
@@ -594,9 +571,6 @@ def jacrev(
def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
_check_callable(fun)
- if has_aux and jax.__version__ < '0.2.28':
- raise NotImplementedError(f'"has_aux" only supported in jax>=0.2.28, but we detect '
- f'the current jax version is {jax.__version__}')
@wraps(fun)
def jacfun(*args, **kwargs):
@@ -769,7 +743,7 @@ def hessian(
return_value=return_value)
-def _vector_grad(func, argnums=0, return_value=False, has_aux=False):
+def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
_check_callable(func)
@wraps(func)
@@ -866,7 +840,7 @@ def vector_grad(
if func is None:
return lambda f: GradientTransform(target=f,
- transform=_vector_grad,
+ transform=functional_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
@@ -875,7 +849,7 @@ def vector_grad(
has_aux=False if has_aux is None else has_aux)
else:
return GradientTransform(target=func,
- transform=_vector_grad,
+ transform=functional_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
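
The autograd refactor above drops the nested `new_transform`/`current_transform_number` bookkeeping in favour of one abstract trace: the target is run under `eval_shape` inside a fresh `VariableStack`, and the discovered stack is cached for later calls. The core caching idea, stripped of the brainpy specifics (a toy cache keyed on function identity; the real code uses `VariableStack`, `eval_shape`, and `cache_stack`):

import jax
import jax.numpy as jnp

_abstract_cache = {}

def trace_once(fn, *args):
    key = id(fn)
    if key not in _abstract_cache:
        # abstract evaluation only: shapes/dtypes are computed, no FLOPs are spent
        _abstract_cache[key] = jax.eval_shape(fn, *args)
    return _abstract_cache[key]

f = lambda x: (x * 2.0).sum()
print(trace_once(f, jnp.ones((3,))))   # a ShapeDtypeStruct describing the scalar result
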
diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py
index 25db8095f..c52845a06 100644
--- a/brainpy/_src/math/object_transform/base.py
+++ b/brainpy/_src/math/object_transform/base.py
@@ -6,7 +6,6 @@
"""
import numbers
-import os
import warnings
from collections import namedtuple
from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional
@@ -14,14 +13,13 @@
import jax
import numpy as np
-from brainpy import errors
+from brainpy._src.math.modes import Mode
from brainpy._src.math.ndarray import (Array, )
from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector)
from brainpy._src.math.object_transform.naming import (get_unique_name,
check_name_uniqueness)
from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
VarList, VarDict)
-from brainpy._src.math.modes import Mode
from brainpy._src.math.sharding import BATCH_AXIS
variable_ = None
@@ -328,7 +326,7 @@ def vars(
nodes = self.nodes(method=method, level=level, include_self=include_self)
gather = ArrayCollector()
for node_path, node in nodes.items():
- for k in dir(node):
+ for k in node.__dict__.keys():
if k in node._excluded_vars:
continue
v = getattr(node, k)
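
The one-line change in `vars()` above (`dir(node)` to `node.__dict__.keys()`) restricts the scan to instance attributes, which is where `Variable`s actually live, and avoids triggering class-level attributes and properties via `getattr`. The difference in miniature:

class Node:
    mode = 'train'                # class attribute: listed by dir(), absent from __dict__

    def __init__(self):
        self.w = 1.0              # instance attribute: the only kind vars() needs to see

node = Node()
print('w' in dir(node), 'mode' in dir(node))            # True True
print('w' in node.__dict__, 'mode' in node.__dict__)    # True False
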
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index 39032da84..3edeb08e8 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -1,37 +1,32 @@
# -*- coding: utf-8 -*-
+
import functools
-from typing import Union, Sequence, Any, Dict, Callable, Optional
import numbers
+from typing import Union, Sequence, Any, Dict, Callable, Optional
import jax
import jax.numpy as jnp
from jax.errors import UnexpectedTracerError
+from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten, tree_unflatten
from tqdm.auto import tqdm
-from jax.experimental.host_callback import id_tap
from brainpy import errors, tools
from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.ndarray import (Array, )
-from .tools import (
- evaluate_dyn_vars,
- evaluate_dyn_vars_with_cache,
- dynvar_deprecation,
- node_deprecation,
- abstract
-)
+from brainpy._src.math.ndarray import (Array, _as_jax_array_)
from .base import BrainPyObject, ObjectTransform
from .naming import (
get_unique_name,
get_stack_cache,
cache_stack
)
-from .variables import (
- Variable,
- VariableStack,
- new_transform,
- current_transform_number,
+from .tools import (
+ eval_shape,
+ dynvar_deprecation,
+ node_deprecation,
+ abstract
)
+from .variables import (Variable, VariableStack)
__all__ = [
'make_loop',
@@ -41,6 +36,7 @@
'cond',
'ifelse',
'for_loop',
+ 'scan',
'while_loop',
]
@@ -421,11 +417,27 @@ def call(pred, x=None):
return ControlObject(call, dyn_vars, repr_fun={'true_fun': true_fun, 'false_fun': false_fun})
+@functools.cache
+def _warp(f):
+ @functools.wraps(f)
+ def new_f(*args, **kwargs):
+ return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
+
+ return new_f
+
+
+def _warp_data(data):
+ def new_f(*args, **kwargs):
+ return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
+
+ return new_f
+
+
def _check_f(f):
if callable(f):
- return f
+ return _warp(f)
else:
- return (lambda *args, **kwargs: f)
+ return _warp_data(f)
def _check_sequence(a):
@@ -525,15 +537,13 @@ def cond(
node_deprecation(child_objs)
dyn_vars = get_stack_cache((true_fun, false_fun))
- if not jax.config.jax_disable_jit:
- if dyn_vars is None:
- with new_transform('cond'):
- dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1)
- dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1)
- dyn_vars = dyn_vars1 + dyn_vars2
- cache_stack((true_fun, false_fun), dyn_vars)
- if current_transform_number() > 0:
- return rets
+ if not jax.config.jax_disable_jit and dyn_vars is None:
+ with VariableStack() as dyn_vars:
+ rets = eval_shape(true_fun, *operands, with_stack=True)[1]
+ _ = eval_shape(false_fun, *operands, with_stack=True)
+ cache_stack((true_fun, false_fun), dyn_vars)
+ if not dyn_vars.is_first_stack():
+ return rets
dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands)
for k in dyn_values.keys():
@@ -557,7 +567,7 @@ def _if_else_return2(conditions, branches):
return branches[-1]
-def all_equal(iterator):
+def _all_equal(iterator):
iterator = iter(iterator)
try:
first = next(iterator)
@@ -664,22 +674,17 @@ def ifelse(
else:
dyn_vars = get_stack_cache(tuple(branches))
if dyn_vars is None:
- with new_transform('ifelse'):
- with VariableStack() as dyn_vars:
- if current_transform_number() > 1:
- rets = [branch(*operands) for branch in branches]
- else:
- rets = [jax.eval_shape(branch, *operands) for branch in branches]
- trees = [jax.tree_util.tree_structure(ret) for ret in rets]
- if not all_equal(trees):
- msg = 'All returns in branches should have the same tree structure. But we got:\n'
- for tree in trees:
- msg += f'- {tree}\n'
- raise TypeError(msg)
+ with VariableStack() as dyn_vars:
+ rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches]
+ trees = [jax.tree_util.tree_structure(ret) for ret in rets]
+ if not _all_equal(trees):
+ msg = 'All returns in branches should have the same tree structure. But we got:\n'
+ for tree in trees:
+ msg += f'- {tree}\n'
+ raise TypeError(msg)
cache_stack(tuple(branches), dyn_vars)
- if current_transform_number():
- return _if_else_return2(conditions, rets)
-
+ if not dyn_vars.is_first_stack():
+ return rets[0]
branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches]
code_scope = {'conditions': conditions, 'branches': branches}
@@ -716,6 +721,7 @@ def _get_for_loop_transform(
unroll: int,
unroll_kwargs: tools.DotDict
):
+ @functools.wraps(body_fun)
def fun2scan(carry, x):
for k in dyn_vars.keys():
dyn_vars[k]._value = carry[k]
@@ -856,35 +862,30 @@ def for_loop(
if not isinstance(operands, (list, tuple)):
operands = (operands,)
- num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]])
bar = None
if progress_bar:
+ num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]])
bar = tqdm(total=num_total)
if jit is None: # jax disable jit
jit = not jax.config.jax_disable_jit
- dyn_vars = get_stack_cache((body_fun, unroll_kwargs))
+ stack = get_stack_cache((body_fun, unroll_kwargs))
if jit:
- if dyn_vars is None:
+ if stack is None:
+ transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar,
+ remat, reverse, unroll, unroll_kwargs)
# TODO: better cache mechanism?
- with new_transform('for_loop'):
- with VariableStack() as dyn_vars:
- transform = _get_for_loop_transform(body_fun, VariableStack(), bar,
- progress_bar, remat, reverse, unroll,
- unroll_kwargs)
- if current_transform_number() > 1:
- rets = transform(operands)
- else:
- rets = jax.eval_shape(transform, operands)
- cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache
- if current_transform_number():
+ with VariableStack() as stack:
+ rets = eval_shape(transform, operands)
+ cache_stack((body_fun, unroll_kwargs), stack) # cache
+ if not stack.is_first_stack():
return rets[1]
del rets
else:
- dyn_vars = VariableStack()
+ stack = VariableStack()
# TODO: cache mechanism?
- transform = _get_for_loop_transform(body_fun, dyn_vars, bar,
+ transform = _get_for_loop_transform(body_fun, stack, bar,
progress_bar, remat, reverse,
unroll, unroll_kwargs)
if jit:
@@ -892,13 +893,128 @@ def for_loop(
else:
with jax.disable_jit():
dyn_vals, out_vals = transform(operands)
- for key in dyn_vars.keys():
- dyn_vars[key]._value = dyn_vals[key]
+ for key in stack.keys():
+ stack[key]._value = dyn_vals[key]
if progress_bar:
bar.close()
+ del dyn_vals, stack
return out_vals
+def _get_scan_transform(
+ body_fun: Callable,
+ dyn_vars: VariableStack,
+ bar: tqdm,
+ progress_bar: bool,
+ remat: bool,
+ reverse: bool,
+ unroll: int,
+):
+ def fun2scan(carry, x):
+ dyn_vars_data, carry = carry
+ for k in dyn_vars.keys():
+ dyn_vars[k]._value = dyn_vars_data[k]
+ carry, results = body_fun(carry, x)
+ if progress_bar:
+ id_tap(lambda *arg: bar.update(), ())
+ carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
+ return (dyn_vars.dict_data(), carry), results
+
+ if remat:
+ fun2scan = jax.checkpoint(fun2scan)
+
+ def call(init, operands):
+ init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
+ return jax.lax.scan(f=fun2scan,
+ init=(dyn_vars.dict_data(), init),
+ xs=operands,
+ reverse=reverse,
+ unroll=unroll)
+
+ return call
+
+
+def scan(
+ body_fun: Callable,
+ init: Any,
+ operands: Any,
+ reverse: bool = False,
+ unroll: int = 1,
+ remat: bool = False,
+ progress_bar: bool = False,
+):
+ """``scan`` control flow with :py:class:`~.Variable`.
+
+ Similar to ``jax.lax.scan``.
+
+ .. versionadded:: 2.4.7
+
+  The per-step outputs of the body function are stacked and returned
+  together with the final loop carry.
+
+ Parameters
+ ----------
+ body_fun: callable
+    A Python function to be scanned. It accepts two arguments: the loop carry and one slice
+    of ``operands`` along its leading axis. It must return a pair ``(new_carry, output)``,
+    where ``output`` is one slice of the stacked result.
+ init: Any
+ An initial loop carry value of type ``c``, which can be a scalar, array, or any pytree
+ (nested Python tuple/list/dict) thereof, representing the initial loop carry value.
+ This value must have the same structure as the first element of the pair returned
+ by ``body_fun``.
+ operands: Any
+ The value over which to scan along the leading axis,
+ where ``operands`` can be an array or any pytree (nested Python
+ tuple/list/dict) thereof with consistent leading axis sizes.
+    If the per-step slice passed to ``body_fun`` should itself be a
+    tuple/list of arrays, make ``operands`` a tuple/list of arrays
+    with matching leading axis sizes.
+ remat: bool
+    Make ``body_fun`` recompute internal linearization points when differentiated.
+ reverse: bool
+ Optional boolean specifying whether to run the scan iteration
+ forward (the default) or in reverse, equivalent to reversing the leading
+ axes of the arrays in both ``xs`` and in ``ys``.
+ unroll: int
+ Optional positive int specifying, in the underlying operation of the
+ scan primitive, how many scan iterations to unroll within a single
+ iteration of a loop.
+ progress_bar: bool
+ Whether we use the progress bar to report the running progress.
+
+    .. versionadded:: 2.4.7
+
+ Returns
+ -------
+  res: Tuple[Any, Any]
+    A pair ``(carry, outs)``: the final loop carry and the stacked per-step outputs of ``body_fun``.
+ """
+ bar = None
+ if progress_bar:
+ num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]])
+ bar = tqdm(total=num_total)
+
+ stack = get_stack_cache(body_fun)
+ if not jax.config.jax_disable_jit and stack is None:
+ transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll)
+ with VariableStack() as stack:
+ rets = eval_shape(transform, init, operands)
+ cache_stack(body_fun, stack) # cache
+ if not stack.is_first_stack():
+ return rets[0][1], rets[1]
+ del rets
+
+ stack = VariableStack() if stack is None else stack
+ transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll)
+ (dyn_vals, carry), out_vals = transform(init, operands)
+ for key in stack.keys():
+ stack[key]._value = dyn_vals[key]
+ if progress_bar:
+ bar.close()
+ return carry, out_vals
+
+
def _get_while_transform(cond_fun, body_fun, dyn_vars):
def _body_fun(op):
dyn_vals, old_vals = op
@@ -992,7 +1108,6 @@ def while_loop(
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
-
"""
dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)
@@ -1000,18 +1115,16 @@ def while_loop(
if not isinstance(operands, (list, tuple)):
operands = (operands,)
- dyn_vars = get_stack_cache((body_fun, cond_fun))
- if not jax.config.jax_disable_jit:
- if dyn_vars is None:
- with new_transform('while_loop'):
- dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1)
- dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1)
- dyn_vars = dyn_vars1 + dyn_vars2
- cache_stack((body_fun, cond_fun), dyn_vars)
- if current_transform_number():
- return rets
- dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
- dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands)
- for k, v in dyn_vars.items():
+ stack = get_stack_cache((body_fun, cond_fun))
+ if not jax.config.jax_disable_jit and stack is None:
+ with VariableStack() as stack:
+ _ = eval_shape(cond_fun, *operands, with_stack=True)
+ rets = eval_shape(body_fun, *operands, with_stack=True)[1]
+ cache_stack((body_fun, cond_fun), stack)
+ if not stack.is_first_stack():
+ return rets
+ stack = VariableStack() if stack is None else stack
+ dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands)
+ for k, v in stack.items():
v._value = dyn_values[k]
return out
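The new ``brainpy.math.scan`` mirrors ``jax.lax.scan`` while keeping track of ``Variable`` state. A minimal usage sketch, consistent with the ``TestScan`` cases added later in this diff (the variable names here are illustrative):

    import brainpy.math as bm

    a = bm.Variable(1)                       # state mutated inside the scanned body

    def body(carry, x):                      # (carry, x) -> (new_carry, y)
        carry += x
        a.value += 1.
        return carry, a.value

    carry, ys = bm.scan(body, bm.zeros(2), bm.arange(10))
    # carry == [45., 45.]; ys stacks the per-step values of `a`
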
diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py
index 7bb36f4e2..73eab2f91 100644
--- a/brainpy/_src/math/object_transform/jit.py
+++ b/brainpy/_src/math/object_transform/jit.py
@@ -11,23 +11,15 @@
from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable
import jax
-from jax.sharding import Sharding
from brainpy import tools, check
-from .tools import (dynvar_deprecation,
- node_deprecation,
- evaluate_dyn_vars_with_cache,
- evaluate_dyn_vars,
- _partial_fun)
from .base import BrainPyObject, ObjectTransform
from .naming import get_stack_cache, cache_stack
+from .tools import (dynvar_deprecation,
+ node_deprecation,
+ eval_shape)
+from .variables import (Variable, VariableStack)
from ..ndarray import Array
-from .variables import (Variable,
- VariableStack,
- outermost_transform,
- transform_stack,
- current_transform_number,
- new_transform)
RandomState = None
@@ -151,16 +143,12 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs):
return changes, out
def _get_transform(self, *args, **kwargs):
- with new_transform(self):
- self._dyn_vars, rets = evaluate_dyn_vars(
- self.fun,
- *args,
- static_argnums=self._static_argnums,
- static_argnames=self._static_argnames,
- use_eval_shape=current_transform_number() <= 1,
- **kwargs
- )
-
+ with VariableStack() as self._dyn_vars:
+ rets = eval_shape(self.fun,
+ *args,
+ **kwargs,
+ static_argnums=self._static_argnums,
+ static_argnames=self._static_argnames)
# in_shardings
if self._in_shardings is None:
in_shardings = None
@@ -186,18 +174,18 @@ def _get_transform(self, *args, **kwargs):
_dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState))
out_shardings = (_dyn_vars_sharing,) + out_shardings
- # jit
- self._transform = jax.jit(
- self._transform_function,
- static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
- static_argnames=self._static_argnames,
- donate_argnums=self._donate_argnums,
- inline=self._inline,
- keep_unused=self._keep_unused,
- abstracted_axes=self._abstracted_axes,
- in_shardings=in_shardings,
- out_shardings=out_shardings,
- )
+ # jit
+ self._transform = jax.jit(
+ self._transform_function,
+ static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
+ static_argnames=self._static_argnames,
+ donate_argnums=self._donate_argnums,
+ inline=self._inline,
+ keep_unused=self._keep_unused,
+ abstracted_axes=self._abstracted_axes,
+ in_shardings=in_shardings,
+ out_shardings=out_shardings,
+ )
return rets
def __call__(self, *args, **kwargs):
@@ -207,7 +195,7 @@ def __call__(self, *args, **kwargs):
if self._transform is None: # initialize the transformation
rets = self._get_transform(*args, **kwargs)
# if not the outermost transformation
- if current_transform_number():
+ if not self._dyn_vars.is_first_stack():
return rets
# call the transformed function
@@ -477,15 +465,8 @@ def call_fun(self, *args, **kwargs):
cache = get_stack_cache(hash_v) # TODO: better cache mechanism
if cache is None:
fun2 = partial(fun, self)
-
- with jax.ensure_compile_time_eval():
- if len(static_argnums) or len(static_argnames):
- fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames)
- else:
- args_, kwargs_, fun3 = args, kwargs, fun2
- with VariableStack() as stack:
- _ = jax.eval_shape(fun3, *args_, **kwargs_)
- del args_, kwargs_
+ with VariableStack() as stack:
+ _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
_transform = jax.jit(
_make_transform(fun2, stack),
static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums),
diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py
index 1c8ca6ef9..1181e003b 100644
--- a/brainpy/_src/math/object_transform/naming.py
+++ b/brainpy/_src/math/object_transform/naming.py
@@ -41,7 +41,7 @@ def get_unique_name(type_: str):
return name
-def clear_name_cache(ignore_warn=False):
+def clear_name_cache(ignore_warn=True):
"""Clear the cached names."""
_name2id.clear()
_typed_names.clear()
@@ -57,6 +57,7 @@ def cache_stack(func, stack):
def clear_stack_cache():
+ """Clear the cached stack."""
for k in tuple(_fun2stack.keys()):
del _fun2stack[k]
diff --git a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py
deleted file mode 100644
index 1eddce048..000000000
--- a/brainpy/_src/math/object_transform/parallels.py
+++ /dev/null
@@ -1,460 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""
-The parallel compilation tools for JAX backend.
-
-1. Vectorize compilation is implemented by the 'vmap()' function
-2. Parallel compilation is implemented by the 'pmap()' function
-
-"""
-
-
-import functools
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-from jax.interpreters.partial_eval import DynamicJaxprTracer
-from jax.interpreters.partial_eval import JaxprTracer
-from jax.interpreters.pxla import ShardedDeviceArray
-
-try:
- from jax.errors import UnexpectedTracerError
-except ImportError:
- from jax.core import UnexpectedTracerError
-
-from brainpy import errors
-from brainpy._src.math.random import RandomState
-from brainpy._src.math.ndarray import Array
-from brainpy.tools import change_func_name
-from .base import BrainPyObject, ArrayCollector
-
-__all__ = [
- 'vmap',
- 'pmap',
-]
-
-
-def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes,
- batch_idx, axis_name, f_name=None):
- @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name)
- def vmapped_func(nonbatched_data, batched_data, *args, **kwargs):
- nonbatched_vars.assign(nonbatched_data)
- batched_vars.assign(batched_data)
- out = func(*args, **kwargs)
- nonbatched_changes = nonbatched_vars.dict()
- batched_changes = batched_vars.dict()
- return nonbatched_changes, batched_changes, out
-
- def call(*args, **kwargs):
- n = args[batch_idx[0]].shape[batch_idx[1]]
- nonbatched_data = nonbatched_vars.dict()
- batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()}
- try:
- out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs)
- except UnexpectedTracerError as e:
- nonbatched_vars.assign(nonbatched_data)
- batched_vars.assign(batched_data)
- raise errors.JaxTracerError() from e
- # for key, v in dyn_changes.items():
- # dyn_vars[key] = reduce_func(v)
- # for key, v in rand_changes.items():
- # rand_vars[key] = reduce_func(v)
- return out
-
- return change_func_name(name=f_name, f=call) if f_name else call
-
-
-def vmap(func, dyn_vars=None, batched_vars=None,
- in_axes=0, out_axes=0, axis_name=None,
- reduce_func=None, auto_infer=False):
- """Vectorization compilation for class objects.
-
- Vectorized compile a function or a module to run in parallel on a single device.
-
- Examples
- --------
-
- Parameters
- ----------
- func : BrainPyObject, function, callable
- The function or the module to compile.
- dyn_vars : dict, sequence
- batched_vars : dict
- in_axes : optional, int, sequence of int
- Specify which input array axes to map over. If each positional argument to
- ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None,
- or a tuple of integers and Nones with length equal to the number of
- positional arguments to ``obj_or_func``. An integer or ``None``
- indicates which array axis to map over for all arguments (with ``None``
- indicating not to map any axis), and a tuple indicates which axis to map
- for each corresponding positional argument. Axis integers must be in the
- range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
- dimensions (axes) of the corresponding input array.
-
- If the positional arguments to ``obj_or_func`` are container types, the
- corresponding element of ``in_axes`` can itself be a matching container,
- so that distinct array axes can be mapped for different container
- elements. ``in_axes`` must be a container tree prefix of the positional
- argument tuple passed to ``obj_or_func``.
-
- At least one positional argument must have ``in_axes`` not None. The sizes
- of the mapped input axes for all mapped positional arguments must all be
- equal.
-
- Arguments passed as keywords are always mapped over their leading axis
- (i.e. axis index 0).
- out_axes : optional, int, tuple/list/dict
- Indicate where the mapped axis should appear in the output. All outputs
- with a mapped axis must have a non-None ``out_axes`` specification. Axis
- integers must be in the range ``[-ndim, ndim)`` for each output array,
- where ``ndim`` is the number of dimensions (axes) of the array returned
- by the :func:`vmap`-ed function, which is one more than the number of
- dimensions (axes) of the corresponding array returned by ``obj_or_func``.
- axis_name : optional
-
- Returns
- -------
- obj_or_func : Any
- Batched/vectorized version of ``obj_or_func`` with arguments that correspond to
- those of ``obj_or_func``, but with extra array axes at positions indicated by
- ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but
- with extra array axes at positions indicated by ``out_axes``.
-
- """
- # if isinstance(func, DynamicalSystem):
- # if len(func.steps): # DynamicalSystem has step functions
- #
- # # dynamical variables
- # dyn_vars = (dyn_vars or func.vars().unique())
- # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector()
- # for key, val in dyn_vars.items():
- # if isinstance(val, RandomState):
- # rand_vars[key] = val
- # else:
- # dyn_vars[key] = val
- #
- # # in axes
- # if in_axes is None:
- # in_axes = {key: (None, 0) for key in func.steps.keys()}
- # elif isinstance(in_axes, int):
- # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()}
- # elif isinstance(in_axes, (tuple, list)):
- # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()}
- # elif isinstance(in_axes, dict):
- # keys = list(func.steps.keys())
- # if keys[0] not in in_axes:
- # in_axes = {key: (None, 0, in_axes) for key in keys}
- # else:
- # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys}
- # assert isinstance(in_axes, dict)
- #
- # # batch size index
- # batch_idx = {}
- # for key, axes in in_axes.items():
- # for i, axis in enumerate(axes[2:]):
- # if axis is not None:
- # batch_idx[key] = (i, axis)
- # break
- # else:
- # raise ValueError(f'Found no batch axis: {axes}.')
- #
- # # out axes
- # if out_axes is None:
- # out_axes = {key: 0 for key in func.steps.keys()}
- # elif isinstance(out_axes, int):
- # out_axes = {key: out_axes for key in func.steps.keys()}
- # elif isinstance(out_axes, (tuple, list)):
- # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()}
- # elif isinstance(out_axes, dict):
- # keys = list(func.steps.keys())
- # if keys[0] not in out_axes:
- # out_axes = {key: (out_axes, 0, 0) for key in keys}
- # else:
- # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys}
- # assert isinstance(out_axes, dict)
- #
- # # reduce_func
- # if reduce_func is None:
- # reduce_func = lambda x: x.mean(axis=0)
- #
- # # vectorized map functions
- # for key in func.steps.keys():
- # func.steps[key] = _make_vmap(func=func.steps[key],
- # dyn_vars=dyn_vars,
- # rand_vars=rand_vars,
- # in_axes=in_axes[key],
- # out_axes=out_axes[key],
- # axis_name=axis_name,
- # batch_idx=batch_idx[key],
- # reduce_func=reduce_func,
- # f_name=key)
- #
- # return func
-
- if callable(func):
- if auto_infer:
- if dyn_vars is not None:
- dyn_vars = dyn_vars
- elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation
- dyn_vars = func.vars().unique()
- elif hasattr(func, '__self__'):
- if isinstance(func.__self__, BrainPyObject):
- dyn_vars = func.__self__.vars().unique()
-
- if dyn_vars is None:
- return jax.vmap(func,
- in_axes=in_axes,
- out_axes=out_axes,
- axis_name=axis_name)
-
- else:
- if isinstance(dyn_vars, Array):
- dyn_vars = [dyn_vars]
- if isinstance(dyn_vars, (tuple, list)):
- dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)}
- assert isinstance(dyn_vars, dict)
-
- # dynamical variables
- _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector()
- for key, val in dyn_vars.items():
- if isinstance(val, RandomState):
- _rand_vars[key] = val
- else:
- _dyn_vars[key] = val
-
- # in axes
- if in_axes is None:
- in_axes = (None, 0)
- elif isinstance(in_axes, (int, dict)):
- in_axes = (None, 0, in_axes)
- elif isinstance(in_axes, (tuple, list)):
- in_axes = (None, 0) + tuple(in_axes)
- assert isinstance(in_axes, (tuple, list))
-
- # batch size index
- batch_idx = {}
- for key, axes in batch_idx.items():
- for i, axis in enumerate(axes[2:]):
- if axis is not None:
- batch_idx[key] = (i, axis)
- break
- else:
- raise ValueError(f'Found no batch axis: {axes}.')
-
- # out axes
- if out_axes is None:
- out_axes = 0
- elif isinstance(out_axes, (int, dict)):
- out_axes = (out_axes, 0, 0)
- elif isinstance(out_axes, (tuple, list)):
- out_axes = tuple(out_axes) + (0, 0)
- assert isinstance(out_axes, (list, tuple))
-
- # reduce_func
- if reduce_func is None:
- reduce_func = lambda x: x.mean(axis=0)
-
- # jit function
- return _make_vmap(func=func,
- nonbatched_vars=_dyn_vars,
- batched_vars=_rand_vars,
- in_axes=in_axes,
- out_axes=out_axes,
- axis_name=axis_name,
- batch_idx=batch_idx)
-
- else:
- raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable '
- f'function, but we got {type(func)}.')
-
-
-def _device_reshape(x):
- """Reshape an input array in order to broadcast to multiple devices."""
- num_device = jax.local_device_count()
-
- if not hasattr(x, 'ndim'):
- raise errors.BrainPyError(f'Expected Array, got {type(x)}. If you are trying to pass a scalar to '
- f'parallel, first convert it to a Array, for example np.float(0.5)')
- if x.ndim == 0:
- return np.broadcast_to(x, [num_device])
- if x.shape[0] % num_device != 0:
- raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among '
- f'{num_device} devices, but does not go equally.')
- return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:])
-
-
-def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0,
- out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None,
- axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None):
- @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name,
- static_broadcasted_argnums=static_broadcasted_argnums, devices=devices,
- backend=backend, axis_size=axis_size, donate_argnums=donate_argnums,
- global_arg_shapes=global_arg_shapes)
- def pmapped_func(dyn_data, rand_data, *args, **kwargs):
- dyn_vars.assign(dyn_data)
- rand_vars.assign(rand_data)
- out = func(*args, **kwargs)
- dyn_changes = dyn_vars.dict()
- rand_changes = rand_vars.dict()
- return out, dyn_changes, rand_changes
-
- def call(*args):
- un_replicated = [k for k, v in dyn_vars.items()
- if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))]
- if len(un_replicated):
- raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.'
- f'did you forget to call xx.replicate() on them?')
- _args = []
- for i, x in enumerate(args):
- if i + 2 in static_broadcasted_argnums:
- _args.append(x)
- else:
- _args.append(jax.tree_map(_device_reshape, [x])[0])
- dyn_data = dyn_vars.dict()
- rand_data = rand_vars.dict()
- output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args)
- dyn_vars.assign(dyn_changes)
- rand_vars.assign(rand_changes)
- return jax.tree_map(reduce_func, output)
-
- return change_func_name(name=f_name, f=call) if f_name else call
-
-
-def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(),
- devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None,
- reduce_func=None):
- """Parallel compilation for class objects.
-
- Parallel compile a function or a module to run on multiple devices in parallel.
-
- Parameters
- ----------
- func
- axis_name
- in_axes
- out_axes
- static_broadcasted_argnums
- devices
- backend
- axis_size
- donate_argnums
- global_arg_shapes
-
- Returns
- -------
-
-
- Examples
- --------
-
-
- """
-
- # if isinstance(func, DynamicalSystem):
- # if len(func.steps): # DynamicalSystem has step functions
- #
- # # dynamical variables
- # all_vars = (dyn_vars or func.vars().unique())
- # dyn_vars = ArrayCollector()
- # rand_vars = ArrayCollector()
- # for key, val in all_vars.items():
- # if isinstance(val, RandomState):
- # rand_vars[key] = val
- # else:
- # dyn_vars[key] = val
- #
- # # reduce function
- # if reduce_func is None:
- # reduce_func = jnp.concatenate
- #
- # # static broadcast-ed arguments
- # if static_broadcasted_argnums is None:
- # static_broadcasted_argnums = ()
- # elif isinstance(static_broadcasted_argnums, int):
- # static_broadcasted_argnums = (static_broadcasted_argnums + 2,)
- # elif isinstance(static_broadcasted_argnums, (tuple, list)):
- # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums)
- # assert isinstance(static_broadcasted_argnums, (tuple, list))
- #
- # # jit functions
- # for key in func.steps.keys():
- # step = func.steps[key]
- # func.steps[key] = _make_pmap(dyn_vars=dyn_vars,
- # rand_vars=rand_vars,
- # func=step,
- # axis_name=axis_name,
- # in_axes=in_axes,
- # out_axes=out_axes,
- # static_broadcasted_argnums=static_broadcasted_argnums,
- # devices=devices,
- # backend=backend,
- # axis_size=axis_size,
- # donate_argnums=donate_argnums,
- # global_arg_shapes=global_arg_shapes,
- # reduce_func=reduce_func,
- # f_name=key)
- # return func
-
- if callable(func):
- if dyn_vars is not None:
- dyn_vars = dyn_vars
- elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation
- dyn_vars = func.vars().unique()
- elif hasattr(func, '__self__'):
- if isinstance(func.__self__, BrainPyObject):
- dyn_vars = func.__self__.vars().unique()
-
- if dyn_vars is None:
- return jax.pmap(func,
- axis_name=axis_name,
- in_axes=in_axes,
- out_axes=out_axes,
- static_broadcasted_argnums=static_broadcasted_argnums,
- devices=devices,
- backend=backend,
- axis_size=axis_size,
- donate_argnums=donate_argnums,
- global_arg_shapes=global_arg_shapes)
- else:
- # dynamical variables
- dyn_vars = ArrayCollector()
- rand_vars = ArrayCollector()
- for key, val in dyn_vars.items():
- if isinstance(val, RandomState):
- rand_vars[key] = val
- else:
- dyn_vars[key] = val
-
- # static broadcast-ed arguments
- if static_broadcasted_argnums is None:
- static_broadcasted_argnums = ()
- elif isinstance(static_broadcasted_argnums, int):
- static_broadcasted_argnums = (static_broadcasted_argnums + 2,)
- elif isinstance(static_broadcasted_argnums, (tuple, list)):
- static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums)
- assert isinstance(static_broadcasted_argnums, (tuple, list))
-
- # reduce function
- if reduce_func is None:
- reduce_func = jnp.concatenate
-
- # jit function
- func.__call__ = _make_pmap(dyn_vars=dyn_vars,
- rand_vars=rand_vars,
- func=func,
- axis_name=axis_name,
- in_axes=in_axes,
- out_axes=out_axes,
- static_broadcasted_argnums=static_broadcasted_argnums,
- devices=devices,
- backend=backend,
- axis_size=axis_size,
- donate_argnums=donate_argnums,
- global_arg_shapes=global_arg_shapes,
- reduce_func=reduce_func)
- return func
-
- else:
- raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, '
- f'but we got {type(func)}.')
diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py
index 7ff2949dd..7a04c2488 100644
--- a/brainpy/_src/math/object_transform/tests/test_controls.py
+++ b/brainpy/_src/math/object_transform/tests/test_controls.py
@@ -1,14 +1,11 @@
# -*- coding: utf-8 -*-
-import sys
import tempfile
import unittest
from functools import partial
import jax
-from jax import vmap
-
from absl.testing import parameterized
-from jax._src import test_util as jtu
+from jax import vmap
import brainpy as bp
import brainpy.math as bm
@@ -132,6 +129,69 @@ def update(self):
self.assertTrue(bm.allclose(cls.a, 10.))
+class TestScan(unittest.TestCase):
+ def test1(self):
+ a = bm.Variable(1)
+
+ def f(carray, x):
+ carray += x
+ a.value += 1.
+ return carray, a
+
+ carry, outs = bm.scan(f, bm.zeros(2), bm.arange(10))
+ self.assertTrue(bm.allclose(carry, 45.))
+ expected = bm.arange(1, 11).astype(outs.dtype)
+ expected = bm.expand_dims(expected, axis=-1)
+ self.assertTrue(bm.allclose(outs, expected))
+
+ def test2(self):
+ a = bm.Variable(1)
+
+ def f(carray, x):
+ carray += x
+ a.value += 1.
+ return carray, a
+
+ @bm.jit
+ def f_outer(carray, x):
+ carry, outs = bm.scan(f, carray, x, unroll=2)
+ return carry, outs
+
+ carry, outs = f_outer(bm.zeros(2), bm.arange(10))
+ self.assertTrue(bm.allclose(carry, 45.))
+ expected = bm.arange(1, 11).astype(outs.dtype)
+ expected = bm.expand_dims(expected, axis=-1)
+ self.assertTrue(bm.allclose(outs, expected))
+
+ def test_disable_jit(self):
+ def cumsum(res, el):
+ res = res + el
+ print(res)
+ return res, res # ("carryover", "accumulated")
+
+ a = bm.array([1, 2, 3, 5, 7, 11, 13, 17]).value
+ result_init = 0
+ with jax.disable_jit():
+ final, result = jax.lax.scan(cumsum, result_init, a)
+
+ b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
+ result_init = 0
+ with jax.disable_jit():
+ final, result = bm.scan(cumsum, result_init, b)
+
+ bm.clear_buffer_memory()
+
+ def test_array_aware_of_bp_array(self):
+ def cumsum(res, el):
+ res = bm.asarray(res + el)
+ return res, res # ("carryover", "accumulated")
+
+ b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
+ result_init = 0
+ with jax.disable_jit():
+ final, result = bm.scan(cumsum, result_init, b)
+
+
class TestCond(unittest.TestCase):
def test1(self):
bm.random.seed(1)
@@ -208,6 +268,27 @@ def f2():
self.assertTrue(f2().size == 200)
+ def test_grad1(self):
+ def F2(x):
+ return bm.ifelse(conditions=(x >= 10,),
+ branches=[lambda x: x,
+ lambda x: x ** 2, ],
+ operands=x)
+
+ self.assertTrue(bm.grad(F2)(9.0) == 18.)
+ self.assertTrue(bm.grad(F2)(11.0) == 1.)
+
+ def test_grad2(self):
+ def F3(x):
+ return bm.ifelse(conditions=(x >= 10, x >= 0),
+ branches=[lambda x: x,
+ lambda x: x ** 2,
+ lambda x: x ** 4, ],
+ operands=x)
+
+ self.assertTrue(bm.grad(F3)(9.0) == 18.)
+ self.assertTrue(bm.grad(F3)(11.0) == 1.)
+
class TestWhile(unittest.TestCase):
def test1(self):
diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py
index 7b519590a..632c6d79e 100644
--- a/brainpy/_src/math/object_transform/tools.py
+++ b/brainpy/_src/math/object_transform/tools.py
@@ -132,19 +132,65 @@ def evaluate_dyn_vars_with_cache(
return stack
+def _partial_fun2(
+ fun: Callable,
+ args: tuple,
+ kwargs: dict,
+ static_argnums: Sequence[int] = (),
+ static_argnames: Sequence[str] = ()
+):
+ num_args = len(args)
+
+ # arguments
+ static_args = dict()
+ dyn_args = []
+ dyn_arg_ids = dict()
+ static_argnums = list(static_argnums)
+ dyn_i = 0
+ for i in range(num_args):
+ if i in static_argnums:
+ static_argnums.remove(i)
+ static_args[i] = args[i]
+ else:
+ dyn_args.append(args[i])
+ dyn_arg_ids[i] = dyn_i
+ dyn_i += 1
+ if len(static_argnums) > 0:
+ raise ValueError(f"Invalid static_argnums: {static_argnums}")
+
+ # keyword arguments
+ static_kwargs, dyn_kwargs = {}, {}
+ for k, arg in kwargs.items():
+ if k in static_argnames:
+ static_kwargs[k] = arg
+ else:
+ dyn_kwargs[k] = arg
+ del args, kwargs, static_argnums, static_argnames
+
+ @wraps(fun)
+ def new_fun(*dynargs, **dynkwargs):
+ return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)],
+ **static_kwargs,
+ **dynkwargs)
+
+ return new_fun, dyn_args, dyn_kwargs
+
+
def eval_shape(
fun: Callable,
*args,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
+ with_stack: bool = False,
**kwargs
):
"""Compute the shape/dtype of ``fun`` without any FLOPs.
Args:
fun: The callable function.
- *args:
- **kwargs:
+ *args: The positional arguments.
+ **kwargs: The keyword arguments.
+ with_stack: Whether to evaluate the function within a local variable stack.
static_argnums: The static argument indices.
static_argnames: The static argument names.
@@ -153,21 +199,30 @@ def eval_shape(
"""
# reorganize the function
if len(static_argnums) or len(static_argnames):
- f2, args, kwargs = _partial_fun(fun, args, kwargs,
- static_argnums=static_argnums,
- static_argnames=static_argnames)
+ f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
else:
- f2, args, kwargs = fun, args, kwargs
+ f2 = fun
# evaluate the function
fun_in_eval_shape.append(fun)
try:
- with jax.ensure_compile_time_eval():
+ if with_stack:
with VariableStack() as stack:
if len(fun_in_eval_shape) > 1:
- returns = fun(*args, **kwargs)
+ returns = f2(*args, **kwargs)
else:
- returns = jax.eval_shape(fun, *args, **kwargs)
+ returns = jax.eval_shape(f2, *args, **kwargs)
+ else:
+ stack = None
+ if len(fun_in_eval_shape) > 1:
+ returns = f2(*args, **kwargs)
+ else:
+ returns = jax.eval_shape(f2, *args, **kwargs)
finally:
fun_in_eval_shape.pop()
- return stack, returns
+ del f2
+ if with_stack:
+ return stack, returns
+ else:
+ return returns
+
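With the new ``with_stack`` flag, ``eval_shape`` either returns only the shape/dtype structure of the outputs or additionally collects the ``VariableStack`` touched while tracing ``fun``. A rough sketch of the two call forms, assuming the internal module path used in this diff:

    import brainpy.math as bm
    from brainpy._src.math.object_transform.tools import eval_shape

    v = bm.Variable(bm.zeros(3))

    def fun(x):
        return x * 2. + v.value              # reads the Variable `v`

    # shape/dtype structure only
    out = eval_shape(fun, bm.ones(3))

    # additionally collect the Variables used by `fun`
    stack, out = eval_shape(fun, bm.ones(3), with_stack=True)
    print(list(stack.keys()), out.shape)
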
diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py
index 5014da0bf..b7babae8d 100644
--- a/brainpy/_src/math/object_transform/variables.py
+++ b/brainpy/_src/math/object_transform/variables.py
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple
import jax
@@ -190,6 +189,14 @@ def remove_by_id(self, *ids, error_when_absent=False):
remove_var_by_id = remove_by_id
+ @classmethod
+ def num_of_stack(cls):
+ return len(var_stack_list)
+
+ @classmethod
+ def is_first_stack(cls):
+ return len(var_stack_list) == 0
+
def __enter__(self) -> 'VariableStack':
self.collect_values() # recollect the original value of each variable
var_stack_list.append(self)
@@ -210,42 +217,6 @@ def __add__(self, other: dict):
var_stack_list: List[VariableStack] = []
-transform_stack: List[Callable] = []
-
-
-@contextmanager
-def new_transform(transform: Any):
- transform_stack.append(transform)
- try:
- yield
- finally:
- transform_stack.pop()
-
-
-def outermost_stack():
- if len(var_stack_list):
- return var_stack_list[0]
- else:
- return None
-
-
-def outermost_transform():
- if len(transform_stack):
- return transform_stack[0]
- else:
- return None
-
-
-def current_transform_number():
- return len(transform_stack)
-
-
-def _stack_add_read(var: 'Variable'):
- pass
-
-
-def _stack_add_write(var: 'Variable'):
- pass
@register_pytree_node_class
diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py
index 4d5acf26a..01f77dbca 100644
--- a/brainpy/_src/math/op_register/__init__.py
+++ b/brainpy/_src/math/op_register/__init__.py
@@ -2,4 +2,6 @@
from .numba_approach import (CustomOpByNumba,
register_op_with_numba,
compile_cpu_signature_with_numba)
+from .taichi_aot_based import clean_caches, check_kernels_count
+from .base import XLACustomOp
from .utils import register_general_batching
diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py
new file mode 100644
index 000000000..342093ea2
--- /dev/null
+++ b/brainpy/_src/math/op_register/ad_support.py
@@ -0,0 +1,51 @@
+import functools
+from functools import partial
+
+from jax import tree_util
+from jax.core import Primitive
+from jax.interpreters import ad
+
+__all__ = [
+ 'defjvp',
+]
+
+
+def defjvp(primitive, *jvp_rules):
+ """Define JVP rules for any JAX primitive.
+
+ This function is similar to ``jax.interpreters.ad.defjvp``.
+ However, the JAX version only supports primitives with ``multiple_results=False``.
+ ``brainpy.math.defjvp`` enables defining an independent JVP rule for each input
+ parameter, regardless of whether ``multiple_results`` is ``True`` or ``False``.
+
+ For examples, please see ``test_ad_support.py``.
+
+ Args:
+ primitive: The JAX ``Primitive`` to define JVP rules for. For an ``XLACustomOp``,
+ use its ``.defjvp()`` method instead.
+ *jvp_rules: The JVP translation rule for each primal.
+ """
+ assert isinstance(primitive, Primitive)
+ if primitive.multiple_results:
+ ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
+ else:
+ ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
+
+
+def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
+ assert primitive.multiple_results
+ val_out = tuple(primitive.bind(*primals, **params))
+ tree = tree_util.tree_structure(val_out)
+ tangents_out = []
+ for rule, t in zip(jvp_rules, tangents):
+ if rule is not None and type(t) is not ad.Zero:
+ r = tuple(rule(t, *primals, **params))
+ tangents_out.append(r)
+ assert tree_util.tree_structure(r) == tree
+ return val_out, functools.reduce(_add_tangents,
+ tangents_out,
+ tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
+
+
+def _add_tangents(xs, ys):
+ return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
+
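A minimal sketch of attaching per-input JVP rules to a multiple-results primitive with the new helper; the primitive and its rules below are illustrative, not part of this diff:

    import jax
    from jax.core import Primitive
    from brainpy._src.math.op_register.ad_support import defjvp

    p = Primitive('double_and_square')
    p.multiple_results = True
    p.def_impl(lambda x: (2. * x, x * x))
    p.def_abstract_eval(lambda x: (x, x))

    # one JVP rule per primal input; each rule returns the tangents of all outputs
    defjvp(p, lambda x_dot, x: (2. * x_dot, 2. * x * x_dot))

    primals_out, tangents_out = jax.jvp(lambda x: p.bind(x), (3.,), (1.,))
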
diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
index 31aef70d6..1824ac911 100644
--- a/brainpy/_src/math/op_register/base.py
+++ b/brainpy/_src/math/op_register/base.py
@@ -8,17 +8,17 @@
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject
-# if jax.__version__ >= '0.4.16':
-# from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
-# else:
-# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
-from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
-from .taichi_aot_based import (register_taichi_cpu_translation_rule,
- register_taichi_gpu_translation_rule,
- encode_md5,
- _preprocess_kernel_call_cpu,
- get_source_with_dependencies)
+
+if jax.__version__ >= '0.4.16':
+ from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
+ from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
+ register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
+else:
+ from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
+ from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
+ register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .utils import register_general_batching
+from brainpy._src.math.op_register.ad_support import defjvp
__all__ = [
'XLACustomOp',
@@ -64,8 +64,8 @@ class XLACustomOp(BrainPyObject):
>>>
>>> # option 2
>>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun,
- >>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32),
- >>> jax.ShapeDtypeStruct(1000, dtype=np.float32)])
+ >>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype),
+ >>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)])
>>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000))
Args:
@@ -74,7 +74,7 @@ class XLACustomOp(BrainPyObject):
batching_translation: Callable. The batching translation rule of JAX.
jvp_translation: Callable. The JVP translation rule of JAX.
transpose_translation: Callable. The transpose translation rule of JAX.
- outs: optional, sequence of `ShapeDtype`. The output information.
+ outs: optional, callable. A function that receives the input arrays and keyword
+ arguments and returns the output information (a sequence of `ShapeDtype`).
name: str. The primitive name.
"""
@@ -85,7 +85,7 @@ def __init__(
batching_translation: Callable = None,
jvp_translation: Callable = None,
transpose_translation: Callable = None,
- outs: Optional[Sequence[ShapeDtype]] = None,
+ outs: Optional[Callable] = None,
name: str = None,
):
super().__init__(name)
@@ -99,8 +99,6 @@ def __init__(
self.primitive.multiple_results = True
# abstract evaluation
- if outs is not None:
- outs = tuple([_transform_to_shapedarray(o) for o in outs])
self.outs = outs
self.primitive.def_abstract_eval(_abstract_eval)
self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
@@ -139,13 +137,15 @@ def __init__(
if transpose_translation is not None:
ad.primitive_transposes[self.primitive] = transpose_translation
- def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None):
+ def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
if outs is None:
- outs = self.outs
+ if self.outs is None:
+ raise ValueError('The output information is not defined.')
+ outs = self.outs(*ins, **kwargs)
assert outs is not None
outs = tuple([_transform_to_shapedarray(o) for o in outs])
ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
- return self.primitive.bind(*ins, outs=outs)
+ return self.primitive.bind(*ins, outs=outs, **kwargs)
def def_abstract_eval(self, fun):
"""Define the abstract evaluation function.
@@ -171,6 +171,14 @@ def def_jvp_rule(self, fun):
"""
ad.primitive_jvps[self.primitive] = fun
+ def defjvp(self, *jvp_rules):
+ """Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``, but supports the Primitive with multiple results.
+
+ Args:
+ jvp_rules: The JVP rules.
+ """
+ defjvp(self.primitive, *jvp_rules)
+
def def_transpose_rule(self, fun):
"""Define the transpose rule.
@@ -218,5 +226,3 @@ def _transform_to_array(a):
def _transform_to_shapedarray(a):
return jax.core.ShapedArray(a.shape, a.dtype)
-
-
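Because ``outs`` may now be a callable, the output layout can be derived from the actual inputs at call time, and it can still be overridden per call. A small sketch in the spirit of the tests in this diff (the kernel below is illustrative):

    import jax
    import numba
    import numpy as np
    import brainpy.math as bm

    @numba.njit
    def double_kernel(x, out):               # outputs are filled in-place
        for i in range(x.shape[0]):
            out[i] = x[i] * 2.0

    # output info computed from the inputs when the op is called
    prim = bm.XLACustomOp(cpu_kernel=double_kernel,
                          outs=lambda x, **kwargs: [jax.ShapeDtypeStruct(x.shape, x.dtype)])
    y, = prim(np.ones(8, dtype=np.float32))

    # a per-call `outs` argument overrides the constructor-level one
    y, = prim(np.ones(8, dtype=np.float32),
              outs=[jax.ShapeDtypeStruct((8,), np.float32)])
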
diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py
index 76362215e..cc2ce5b4c 100644
--- a/brainpy/_src/math/op_register/numba_approach/__init__.py
+++ b/brainpy/_src/math/op_register/numba_approach/__init__.py
@@ -6,7 +6,7 @@
from typing import Union, Sequence
import numba
-from jax import core
+import jax
from jax.interpreters import xla, batching, ad
from jax.tree_util import tree_map
from numba.core.dispatcher import Dispatcher
@@ -40,8 +40,8 @@ class CustomOpByNumba(BrainPyObject):
The function to make the concrete computation. This function receives inputs,
and returns outputs. For example:
- >>> def con_compute(inp1, inp2, inp3, ...):
- >>> return out1, out2
+ >>> def con_compute(inp1, inp2, inp3, ..., out1, out2, ...):
+ >>> pass
"""
def __init__(
@@ -86,7 +86,7 @@ def __call__(self, *args, **kwargs):
def register_op_with_numba(
op_name: str,
cpu_func: Callable,
- out_shapes: Union[Callable, core.ShapedArray, Sequence[core.ShapedArray]],
+ out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]],
gpu_func_translation: Callable = None,
batching_translation: Callable = None,
jvp_translation: Callable = None,
@@ -130,12 +130,19 @@ def register_op_with_numba(
A JAX Primitive object.
"""
+ if jax.__version__ > '0.4.23':
+ raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are '
+ f'only supported in JAX version <= 0.4.23. \n'
+ f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. '
+ f'For more information, please refer to the documentation: '
+ f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.')
+
if out_shapes is None:
raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or '
'a sequence of `ShapedArray`. If it is a function, it takes as input the argument '
'shapes and dtypes and should return correct output shapes of `ShapedArray`.')
- prim = core.Primitive(op_name)
+ prim = jax.core.Primitive(op_name)
prim.multiple_results = multiple_results
# user defined function
@@ -149,12 +156,12 @@ def abs_eval_rule(*input_shapes, **info):
else:
shapes = out_shapes
- if isinstance(shapes, core.ShapedArray):
+ if isinstance(shapes, jax.core.ShapedArray):
assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data."
elif isinstance(shapes, (tuple, list)):
assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data."
for elem in shapes:
- if not isinstance(elem, core.ShapedArray):
+ if not isinstance(elem, jax.core.ShapedArray):
raise ValueError(f'Elements in "out_shapes" must be instances of '
f'jax.abstract_arrays.ShapedArray, but we got '
f'{type(elem)}: {elem}')
diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py
index fb51b5dbf..fb76aed24 100644
--- a/brainpy/_src/math/op_register/numba_based.py
+++ b/brainpy/_src/math/op_register/numba_based.py
@@ -16,6 +16,10 @@
'register_numba_mlir_cpu_translation_rule',
]
+
+# [void* pointer,
+# const char *name,
+# PyCapsule_Destructor destructor]
ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
@@ -100,6 +104,7 @@ def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs):
def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False):
+ # not supported for jax >= 0.4.24
xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule,
cpu_kernel,
debug)
@@ -124,38 +129,44 @@ def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
for i in range(len(input_shapes))]
- args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
- for i in range(len(output_shapes))]
+ if len(output_shapes) > 1:
+ args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
+ for i in range(len(output_shapes))]
+ sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
+ else:
+ args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
+ sig = types.void(types.voidptr, types.CPointer(types.voidptr))
args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
code_string = '''
- def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
+def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
{args_in}
{args_out}
func_to_call({args_call})
'''.format(args_in="\n ".join(args_in),
args_out="\n ".join(args_out),
args_call=", ".join(args_call))
- if debug: print(code_string)
+ if debug:
+ print(code_string)
exec(compile(code_string.strip(), '', 'exec'), code_scope)
new_f = code_scope['numba_cpu_custom_call_target']
# register
- xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)))(new_f)
+ xla_c_rule = cfunc(sig)(new_f)
target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
xla_client.register_custom_call_target(target_name, capsule, "cpu")
# call
- call = custom_call(call_target_name=target_name,
- operands=list(ins),
- operand_layouts=list(input_layouts),
- result_layouts=list(output_layouts),
- result_types=list(result_types)).results
- return call
+ return custom_call(
+ call_target_name=target_name,
+ operands=ins,
+ operand_layouts=list(input_layouts),
+ result_layouts=list(output_layouts),
+ result_types=list(result_types),
+ has_side_effect=False,
+ ).results
def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False):
rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
mlir.register_lowering(primitive, rule, platform='cpu')
-
-
diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py
index 06d0508a1..7fac4452d 100644
--- a/brainpy/_src/math/op_register/taichi_aot_based.py
+++ b/brainpy/_src/math/op_register/taichi_aot_based.py
@@ -1,18 +1,25 @@
+import contextlib
import hashlib
import inspect
+import io
import os
import pathlib
+import platform
import re
+import shutil
from functools import partial, reduce
from typing import Any, Sequence
import jax.core
import numpy as np
-from jax.interpreters import xla
+from jax.interpreters import xla, mlir
from jax.lib import xla_client
+from jaxlib.hlo_helpers import custom_call
+from brainpy._src.dependency_check import (import_taichi,
+ import_brainpylib_cpu_ops,
+ import_brainpylib_gpu_ops)
from .utils import _shape_to_layout
-from brainpy._src.dependency_check import import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops
### UTILS ###
@@ -35,34 +42,71 @@ def encode_md5(source: str) -> str:
return md5.hexdigest()
+# check kernels count
+def check_kernels_count() -> int:
+ if not os.path.exists(kernels_aot_path):
+ return 0
+ kernels_count = 0
+ dir1 = os.listdir(kernels_aot_path)
+ for i in dir1:
+ dir2 = os.listdir(os.path.join(kernels_aot_path, i))
+ kernels_count += len(dir2)
+ return kernels_count
+
+# clean caches
+def clean_caches(kernels_name: list[str]=None):
+ if kernels_name is None:
+ if not os.path.exists(kernels_aot_path):
+ raise FileNotFoundError("The kernels cache folder does not exist. "
+ "Please define a kernel using `taichi.kernel` "
+ "and customize the operator using `bm.XLACustomOp` "
+ "before calling the operator.")
+ shutil.rmtree(kernels_aot_path)
+ print('Cleaned all kernel caches successfully')
+ return
+ for kernel_name in kernels_name:
+ try:
+ shutil.rmtree(os.path.join(kernels_aot_path, kernel_name))
+ except FileNotFoundError:
+ raise FileNotFoundError(f'Kernel {kernel_name} does not exist.')
+ print('Cleaned kernel caches successfully')
+# TODO: the regex-based source collection below is not a very robust
+# way to gather a kernel's dependencies
# get source with dependencies
def get_source_with_dependencies(func, visited=None):
if visited is None:
visited = set()
source = inspect.getsource(func)
-
if func in visited:
return ''
visited.add(func)
-
module = inspect.getmodule(func)
-
dependent_funcs = re.findall(r'(\w+)\(', source)
for func_name in dependent_funcs:
dependent_func = getattr(module, func_name, None)
if callable(dependent_func):
source += get_source_with_dependencies(dependent_func, visited)
-
return source
+# check if Metal is supported
+def is_metal_supported():
+ # first check if we are on macOS
+ if platform.system() != 'Darwin':
+ return False
+ if platform.processor() != 'arm':
+ return False
+ return True
+
+
### VARIABLES ###
home_path = get_home_dir()
kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels')
+is_metal_device = is_metal_supported()
# check if a kernel exists in the database
@@ -107,7 +151,9 @@ def _array_to_field(dtype, shape) -> Any:
elif dtype == np.float64:
dtype = ti.float64
else:
- raise TypeError
+ raise NotImplementedError(f'Currently we do not support dtype {dtype} in Taichi. '
+ f'If you think it is necessary, please open an issue at '
+ f'https://github.com/brainpy/BrainPy/issues/new')
return ti.field(dtype=dtype, shape=shape)
@@ -122,18 +168,26 @@ def _build_kernel(
ti = import_taichi()
# init arch
- arch = None
if device == 'cpu':
- arch = ti.x64
+ if is_metal_device:
+ arch = ti.arm64
+ device = 'arm64'
+ else:
+ arch = ti.x64
elif device == 'gpu':
arch = ti.cuda
-
- ti.init(arch=arch)
+ else:
+ raise ValueError(f'Unknown device: {device}')
+ with contextlib.redirect_stdout(io.StringIO()):
+ ti.init(arch=arch)
# check arch is available
if ti.lang.impl.current_cfg().arch != arch:
raise RuntimeError(f"Arch {arch} is not available")
+ # get kernel name
+ kernel_name = kernel.__name__
+
# replace the name of the func
kernel.__name__ = f'taichi_kernel_{device}'
@@ -153,6 +207,9 @@ def _build_kernel(
mod.add_kernel(kernel, template_args=template_args_dict)
mod.save(kernel_path)
+ # restore the original kernel name
+ kernel.__name__ = kernel_name
+
### KERNEL CALL PREPROCESS ###
@@ -229,7 +286,7 @@ def _preprocess_kernel_call_cpu(
return in_out_info
-def preprocess_kernel_call_gpu(
+def _preprocess_kernel_call_gpu(
source_md5_encode: str,
ins: dict,
outs: dict,
@@ -276,11 +333,18 @@ def preprocess_kernel_call_gpu(
return opaque
+
+
+
def _XlaOp_to_ShapedArray(c, xla_op):
xla_op = c.get_shape(xla_op)
return jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type())
+def _mlir_to_ShapedArray(c, op):
+ return op
+
+
def _kernel_to_code(kernel, abs_ins, abs_outs, platform):
codes = f'[taichi {platform} kernel]\n' + get_source_with_dependencies(kernel)
codes += '\n[ins]: {}'.format("-".join([f'{v.dtype}[{v.shape}]' for v in abs_ins]))
@@ -288,17 +352,16 @@ def _kernel_to_code(kernel, abs_ins, abs_outs, platform):
return codes
-def _compile_kernel(kernel, c, platform, *ins, **kwargs):
+def _compile_kernel(abs_ins, kernel, platform: str, **kwargs):
# input and output abstract information
abs_outs = kwargs['outs']
- abs_ins = [_XlaOp_to_ShapedArray(c, v) for v in ins]
# kernel to code
codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
- source_md5_encode = encode_md5(codes)
+ source_md5_encode = os.path.join(kernel.__name__, encode_md5(codes))
# create ins, outs dict from kernel's args
- in_num = len(ins)
+ in_num = len(abs_ins)
names = tuple(inspect.signature(kernel).parameters.keys())
in_names, out_names = names[:in_num], names[in_num:]
ins_dict = {key: (abs_ins[i].dtype, abs_ins[i].shape) for i, key in enumerate(in_names)}
@@ -309,13 +372,16 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
try:
_build_kernel(source_md5_encode, kernel, ins_dict, outs_dict, platform)
except Exception as e:
- os.removedirs(os.path.join(kernels_aot_path, source_md5_encode))
+ try:
+ os.removedirs(os.path.join(kernels_aot_path, source_md5_encode))
+ except Exception:
+ raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e
raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e
# returns
if platform in ['gpu', 'cuda']:
import_brainpylib_gpu_ops()
- opaque = preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
+ opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
return opaque
elif platform == 'cpu':
import_brainpylib_cpu_ops()
@@ -325,12 +391,25 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
raise ValueError(f'Unknown platform: {platform}')
-def _taichi_cpu_translation_rule(kernel, c, *ins, **kwargs):
- in_out_info = _compile_kernel(kernel, c, 'cpu', *ins, **kwargs)
+def _get_abs_ins(c, ins):
+ abs_ins = []
+ for v in ins:
+ xla_op = c.get_shape(v)
+ abs_ins.append(jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type()))
+ return abs_ins
+
+
+def _taichi_xla_cpu_translation_rule(kernel, c, *ins, **kwargs):
+ in_out_info = _compile_kernel(_get_abs_ins(c, ins), kernel, 'cpu', **kwargs)
ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins)
+ if is_metal_device:
+ fn = b'taichi_kernel_aot_call_cpu_arm64'
+ else:
+ fn = b'taichi_kernel_aot_call_cpu'
+
return xla_client.ops.CustomCallWithLayout(
c,
- b'taichi_kernel_aot_call_cpu',
+ fn,
operands=ins,
operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
shape_with_layout=xla_client.Shape.tuple_shape(
@@ -340,8 +419,8 @@ def _taichi_cpu_translation_rule(kernel, c, *ins, **kwargs):
)
-def _taichi_gpu_translation_rule(kernel, c, *ins, **kwargs):
- opaque = _compile_kernel(kernel, c, 'gpu', *ins, **kwargs)
+def _taichi_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
+ opaque = _compile_kernel(_get_abs_ins(c, ins), kernel, 'gpu', **kwargs)
return xla_client.ops.CustomCallWithLayout(
c,
b'taichi_kernel_aot_call_gpu',
@@ -355,9 +434,61 @@ def _taichi_gpu_translation_rule(kernel, c, *ins, **kwargs):
)
-def register_taichi_cpu_translation_rule(primitive, cpu_kernel):
- xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_cpu_translation_rule, cpu_kernel)
+def register_taichi_aot_xla_cpu_translation_rule(primitive, cpu_kernel):
+ xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_xla_cpu_translation_rule, cpu_kernel)
+
+
+def register_taichi_aot_xla_gpu_translation_rule(primitive, gpu_kernel):
+ xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_xla_gpu_translation_rule, gpu_kernel)
+
+
+def _taichi_mlir_cpu_translation_rule(kernel, c, *ins, **kwargs):
+ in_out_info = _compile_kernel(c.avals_in, kernel, 'cpu', **kwargs)
+ ins = [mlir.ir_constant(v) for v in in_out_info] + list(ins)
+ input_layouts = [_shape_to_layout(arr.shape) for arr in in_out_info] + [_shape_to_layout(a.shape) for a in c.avals_in]
+ output_layouts = tuple([_shape_to_layout(out.shape) for out in c.avals_out])
+ result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
+ if is_metal_device:
+ if len(output_layouts) == 1:
+ fn = 'taichi_kernel_aot_call_cpu_arm64_single_result'
+ else:
+ fn = 'taichi_kernel_aot_call_cpu_arm64'
+ else:
+ if len(output_layouts) == 1:
+ fn = 'taichi_kernel_aot_call_cpu_single_result'
+ else:
+ fn = 'taichi_kernel_aot_call_cpu'
+ return custom_call(
+ call_target_name=fn,
+ operands=ins,
+ operand_layouts=list(input_layouts),
+ result_layouts=list(output_layouts),
+ result_types=list(result_types),
+ has_side_effect=False,
+ ).results
+
+
+def _taichi_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
+ opaque = _compile_kernel(c.avals_in, kernel, 'gpu', **kwargs)
+ input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
+ result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
+ output_layouts = [_shape_to_layout(out.shape) for out in c.avals_out]
+ return custom_call(
+ call_target_name='taichi_kernel_aot_call_gpu',
+ operands=ins,
+ operand_layouts=list(input_layouts),
+ result_layouts=list(output_layouts),
+ result_types=list(result_types),
+ backend_config=opaque,
+ has_side_effect=False,
+ ).results
+
+
+def register_taichi_aot_mlir_cpu_translation_rule(primitive, cpu_kernel):
+ rule = partial(_taichi_mlir_cpu_translation_rule, cpu_kernel)
+ mlir.register_lowering(primitive, rule, platform='cpu')
-def register_taichi_gpu_translation_rule(primitive, gpu_kernel):
- xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_gpu_translation_rule, gpu_kernel)
+def register_taichi_aot_mlir_gpu_translation_rule(primitive, gpu_kernel):
+ rule = partial(_taichi_mlir_gpu_translation_rule, gpu_kernel)
+ mlir.register_lowering(primitive, rule, platform='gpu')
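The new cache helpers give users control over the Taichi AOT kernel cache under ``~/.brainpy/kernels``. A usage sketch, assuming they are re-exported under ``brainpy.math`` like the other op-register utilities in this diff:

    import brainpy.math as bm

    # number of compiled Taichi AOT kernels currently cached on disk
    print(bm.check_kernels_count())

    # wipe the whole cache folder, or pass a list of kernel function names
    # to remove only their cached builds (missing names raise FileNotFoundError)
    bm.clean_caches()
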
diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py
new file mode 100644
index 000000000..24f010a12
--- /dev/null
+++ b/brainpy/_src/math/op_register/tests/test_ad_support.py
@@ -0,0 +1,138 @@
+from typing import Tuple
+
+import jax
+import numba
+from jax import core
+from jax import numpy as jnp
+from jax.interpreters import ad
+
+import brainpy as bp
+import brainpy.math as bm
+
+bm.set_platform('cpu')
+
+
+def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False, ):
+ data = jnp.atleast_1d(bm.as_jax(data))
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ vector = bm.as_jax(vector)
+ if vector.dtype == jnp.bool_:
+ vector = bm.as_jax(vector, dtype=data.dtype)
+ outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)]
+ if transpose:
+ return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
+ else:
+ return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
+
+
+@numba.njit(fastmath=True)
+def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val):
+ res_val.fill(0)
+ if values.shape[0] == 1:
+ values = values[0]
+ for row_i in range(vector.shape[0]):
+ v = vector[row_i]
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ res_val[col_indices[j]] += values * v
+ else:
+ for row_i in range(vector.shape[0]):
+ v = vector[row_i]
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ res_val[col_indices[j]] += v * values[j]
+
+
+@numba.njit(fastmath=True, parallel=True, nogil=True)
+def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val):
+ res_val.fill(0)
+ # csr mat @ vec
+ if values.shape[0] == 1:
+ values = values[0]
+ for row_i in numba.prange(res_val.shape[0]):
+ r = 0.
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ r += values * vector[col_indices[j]]
+ res_val[row_i] = r
+ else:
+ for row_i in numba.prange(res_val.shape[0]):
+ r = 0.
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ r += values[j] * vector[col_indices[j]]
+ res_val[row_i] = r
+
+
+def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
+ return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
+
+
+def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
+ return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose)
+
+
+def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs):
+ if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+ raise ValueError("Cannot transpose with respect to sparse indices.")
+
+ ct = ct[0]
+ if ad.is_undefined_primal(vector):
+ ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose)
+ return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+
+ else:
+ if type(ct) is ad.Zero:
+ ct_data = ad.Zero(data)
+ else:
+ if data.aval.shape[0] == 1: # scalar
+ ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
+ ct_data = jnp.inner(ct, ct_data)
+ else: # heterogeneous values
+ row, col = bm.sparse.csr_to_coo(indices, indptr)
+ ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
+ return ct_data, indices, indptr, vector
+
+
+prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp)
+prim_trans.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
+prim_trans.def_transpose_rule(_csrmv_cusparse_transpose)
+
+prim = bm.XLACustomOp(_csr_matvec_numba_imp)
+prim.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
+prim.def_transpose_rule(_csrmv_cusparse_transpose)
+
+
+def sum_op(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)
+ return r.sum()
+
+ return func
+
+
+def try_a_trial(transpose, shape):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ heter_data = rng.random(indices.shape)
+ heter_data = bm.as_jax(heter_data)
+ vector = rng.random(shape[0] if transpose else shape[1])
+ vector = bm.as_jax(vector)
+
+ r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs)), argnums=(0, 3))(
+ heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+ r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))(
+ heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+ print(r5)
+ print(r6)
+ assert bm.allclose(r5[0], r6[0])
+ assert bm.allclose(r5[1], r6[1][0])
+
+
+def test():
+ transposes = [True, False]
+ shapes = [(100, 200), (10, 1000), (2, 2000)]
+
+ for transpose in transposes:
+ for shape in shapes:
+ try_a_trial(transpose, shape)
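
The `csrmv` wrapper above routes to the Numba kernels registered through `bm.XLACustomOp`, with JVP and transpose rules attached so gradients flow through both the sparse data and the dense vector. As a rough usage sketch of the forward call (the matrix size, connection density, and dtype are illustrative; `csrmv` refers to the wrapper defined in this test file):

    import brainpy as bp
    import brainpy.math as bm

    bm.set_platform('cpu')

    shape = (100, 200)                                           # illustrative size
    indices, indptr = bp.conn.FixedProb(0.1)(*shape).require('pre2post')
    data = bm.random.rand(indices.shape[0])                      # heterogeneous CSR values
    vector = bm.random.rand(shape[1])                            # dense right-hand-side vector

    # forward product; the custom op returns a list with one output array
    out = csrmv(bm.as_jax(data), bm.as_jax(indices), bm.as_jax(indptr),
                bm.as_jax(vector), shape=shape, transpose=False)
    print(out[0].shape)  # (100,)
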
diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py
index dd0a38dbf..968155ef9 100644
--- a/brainpy/_src/math/op_register/tests/test_numba_based.py
+++ b/brainpy/_src/math/op_register/tests/test_numba_based.py
@@ -1,31 +1,32 @@
-import jax.core
-import brainpy.math as bm
-import numba
-
-
-@numba.njit(fastmath=True)
-def numba_event_csrmv(weight, indices, vector, outs):
- outs.fill(0)
- weight = weight[()] # 0d
- for row_i in range(vector.shape[0]):
- if vector[row_i]:
- for j in indices[row_i]:
- outs[j] += weight
-
-
-prim = bm.XLACustomOp(numba_event_csrmv)
-
-
-def call(s=100):
- indices = bm.random.randint(0, s, (s, 80))
- vector = bm.random.rand(s) < 0.1
- out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])
- print(out[0].shape)
-
-
-def test_event_ELL():
- call(1000)
- call(100)
- bm.clear_buffer_memory()
-
-
+import jax.core
+import brainpy.math as bm
+import numba
+
+bm.set_platform('cpu')
+
+@numba.njit(fastmath=True)
+def numba_event_csrmv(weight, indices, vector, outs):
+ outs.fill(0)
+ weight = weight[()] # 0d
+ for row_i in range(vector.shape[0]):
+ if vector[row_i]:
+ for j in indices[row_i]:
+ outs[j] += weight
+
+
+prim = bm.XLACustomOp(numba_event_csrmv)
+
+
+def call(s=100):
+ indices = bm.random.randint(0, s, (s, 80))
+ vector = bm.random.rand(s) < 0.1
+ out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])
+ print(out[0].shape)
+
+
+def test_event_ELL():
+ call(1000)
+ call(100)
+ bm.clear_buffer_memory()
+
+
diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py
index 14ee77a81..03023754c 100644
--- a/brainpy/_src/math/op_register/tests/test_taichi_based.py
+++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py
@@ -1,55 +1,48 @@
import jax
import jax.numpy as jnp
-import taichi as taichi
-import pytest
-import platform
+import taichi as ti
import brainpy.math as bm
bm.set_platform('cpu')
-if not platform.platform().startswith('Windows'):
- pytest.skip(allow_module_level=True)
-
-
-# @ti.kernel
-# def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
-# vector: ti.types.ndarray(ndim=1),
-# weight: ti.types.ndarray(ndim=1),
-# out: ti.types.ndarray(ndim=1)):
-# weight_0 = weight[0]
-# num_rows, num_cols = indices.shape
-# ti.loop_config(serialize=True)
-# for i in range(num_rows):
-# if vector[i]:
-# for j in range(num_cols):
-# out[indices[i, j]] += weight_0
-
-@taichi.func
-def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32:
+
+@ti.func
+def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
return weight[0]
-@taichi.func
-def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32):
+@ti.func
+def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
out[index] += weight_val
-@taichi.kernel
-def event_ell_cpu(indices: taichi.types.ndarray(ndim=2),
- vector: taichi.types.ndarray(ndim=1),
- weight: taichi.types.ndarray(ndim=1),
- out: taichi.types.ndarray(ndim=1)):
+@ti.kernel
+def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
+ vector: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
weight_val = get_weight(weight)
num_rows, num_cols = indices.shape
- taichi.loop_config(serialize=True)
+ ti.loop_config(serialize=True)
for i in range(num_rows):
if vector[i]:
for j in range(num_cols):
update_output(out, indices[i, j], weight_val)
+@ti.kernel
+def event_ell_gpu(indices: ti.types.ndarray(ndim=2),
+ vector: ti.types.ndarray(ndim=1),
+ weight: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ weight_val = get_weight(weight)
+ num_rows, num_cols = indices.shape
+ for i in range(num_rows):
+ if vector[i]:
+ for j in range(num_cols):
+ update_output(out, indices[i, j], weight_val)
-prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
+prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)
def test_taichi_op_register():
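
With both `event_ell_cpu` and `event_ell_gpu` registered on `prim`, the same call site dispatches to whichever backend `bm.set_platform` selected. The invocation pattern follows the tests in this patch; the sizes below are illustrative:

    import jax
    import jax.numpy as jnp
    import brainpy.math as bm

    s = 1000                                        # illustrative size
    indices = bm.random.randint(0, s, (s, 100))     # ELL-style index matrix
    vector = bm.random.rand(s) < 0.1                # boolean event vector
    weight = bm.array([1.0])

    # one float32 output buffer of length s
    out = prim(indices, vector, weight,
               outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
    print(out[0].shape)  # (1000,)
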
diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
new file mode 100644
index 000000000..1bebcdafe
--- /dev/null
+++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
@@ -0,0 +1,54 @@
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import platform
+import pytest
+import taichi
+
+if not platform.platform().startswith('Windows'):
+ pytest.skip(allow_module_level=True)
+
+@taichi.func
+def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32:
+ return weight[0]
+
+
+@taichi.func
+def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32):
+ out[index] += weight_val
+
+@taichi.kernel
+def event_ell_cpu(indices: taichi.types.ndarray(ndim=2),
+ vector: taichi.types.ndarray(ndim=1),
+ weight: taichi.types.ndarray(ndim=1),
+ out: taichi.types.ndarray(ndim=1)):
+ weight_val = get_weight(weight)
+ num_rows, num_cols = indices.shape
+ taichi.loop_config(serialize=True)
+ for i in range(num_rows):
+ if vector[i]:
+ for j in range(num_cols):
+ update_output(out, indices[i, j], weight_val)
+
+prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
+
+def test_taichi_clean_cache():
+ s = 1000
+ indices = bm.random.randint(0, s, (s, 1000))
+ vector = bm.random.rand(s) < 0.1
+ weight = bm.array([1.0])
+
+ out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
+
+ out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
+
+ print(out)
+ bm.clear_buffer_memory()
+
+ print('kernels: ', bm.check_kernels_count())
+
+ bm.clean_caches()
+
+ print('kernels: ', bm.check_kernels_count())
+
+# test_taichi_clean_cache()
\ No newline at end of file

diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py
index 31e97df88..94aeebb16 100644
--- a/brainpy/_src/math/others.py
+++ b/brainpy/_src/math/others.py
@@ -1,20 +1,27 @@
# -*- coding: utf-8 -*-
-from typing import Optional
+from typing import Optional, Union
+import jax
import jax.numpy as jnp
from jax.tree_util import tree_map
from brainpy import check, tools
+from .compat_numpy import fill_diagonal
from .environment import get_dt, get_int
+from .interoperability import as_jax
from .ndarray import Array
-from .compat_numpy import fill_diagonal
__all__ = [
'shared_args_over_time',
'remove_diag',
'clip_by_norm',
+ 'exprel',
+ 'is_float_type',
+ # 'reduce',
+ 'add_axis',
+ 'add_axes',
]
@@ -82,3 +89,56 @@ def f(l):
return l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm)
return tree_map(f, t)
+
+
+def _exprel(x, threshold):
+ def true_f(x):
+ x2 = x * x
+ return 1. + x / 2. + x2 / 6. + x2 * x / 24.0 # + x2 * x2 / 120.
+
+ def false_f(x):
+ return (jnp.exp(x) - 1) / x
+
+ # return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x)
+ return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x)
+
+
+def exprel(x, threshold: float = None):
+ """Relative error exponential, ``(exp(x) - 1)/x``.
+
+ When ``x`` is near zero, ``exp(x)`` is near 1, so the numerical calculation of ``exp(x) - 1`` can
+ suffer from catastrophic loss of precision. ``exprel(x)`` is implemented to avoid the loss of
+ precision that occurs when ``x`` is near zero.
+
+ Args:
+ x: ndarray. Input array. ``x`` must contain real numbers.
+    threshold: float. The absolute-value threshold below which a Taylor-series approximation is used instead of ``(exp(x) - 1)/x``; if ``None``, it is chosen from the dtype of ``x``.
+
+ Returns:
+ ``(exp(x) - 1)/x``, computed element-wise.
+ """
+ x = as_jax(x)
+ if threshold is None:
+ if hasattr(x, 'dtype') and x.dtype == jnp.float64:
+ threshold = 1e-8
+ else:
+ threshold = 1e-5
+ return _exprel(x, threshold)
+
+
+def is_float_type(x: Union[Array, jax.Array]):
+ return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")
+
+
+def add_axis(x: Union[Array, jax.Array], new_position: int):
+ x = as_jax(x)
+ return jnp.expand_dims(x, new_position)
+
+
+def add_axes(x: Union[Array, jax.Array], n_axes, pos2len):
+ x = as_jax(x)
+ repeats = [1] * n_axes
+ for axis_position, axis_length in pos2len.items():
+ x = add_axis(x, axis_position)
+ repeats[axis_position] = axis_length
+ return jnp.tile(x, repeats)
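
Among the new helpers in `others.py`, `exprel` falls back to a short Taylor series once ``|x|`` drops below the dtype-dependent threshold, so it stays accurate where ``(exp(x) - 1) / x`` would suffer cancellation, and `add_axes` inserts new axes and tiles them to the requested lengths. A quick sketch, assuming these helpers are re-exported under `brainpy.math` like the existing utilities in this module:

    import brainpy.math as bm

    # assumes exprel/add_axes are exposed via brainpy.math (as other helpers in others.py are)
    print(bm.exprel(1.0))     # ~1.71828, i.e. (e - 1) / 1
    print(bm.exprel(0.0))     # 1.0, the limit of (exp(x) - 1) / x as x -> 0
    print(bm.exprel(1e-12))   # ~1.0, no catastrophic cancellation

    x = bm.arange(3)                   # shape (3,)
    y = bm.add_axes(x, 2, {0: 4})      # new axis 0 tiled to length 4
    print(y.shape)                     # (4, 3)
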
diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py
index 964d3f51e..d0f74bf23 100644
--- a/brainpy/_src/math/random.py
+++ b/brainpy/_src/math/random.py
@@ -4,16 +4,16 @@
from collections import namedtuple
from functools import partial
from operator import index
-from typing import Optional, Union
+from typing import Optional, Union, Sequence
import jax
import numpy as np
from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
+from jax._src.array import ArrayImpl
from jax.experimental.host_callback import call
from jax.tree_util import register_pytree_node_class
-from jax._src.array import ArrayImpl
-from brainpy.check import jit_error
+from brainpy.check import jit_error_checking, jit_error_checking_no_args
from .compat_numpy import shape
from .environment import get_int
from .ndarray import Array, _return
@@ -40,6 +40,8 @@
'rand_like', 'randint_like', 'randn_like',
]
+JAX_RAND_KEY = jax.Array
+
def _formalize_key(key):
if isinstance(key, int):
@@ -60,7 +62,7 @@ def _size2shape(size):
elif isinstance(size, (tuple, list)):
return tuple(size)
else:
- return (size, )
+ return (size,)
def _check_shape(name, shape, *param_shapes):
@@ -565,12 +567,16 @@ def split_keys(self, n):
# random functions #
# ---------------- #
- def rand(self, *dn, key=None):
+ def rand(self, *dn, key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.uniform(key, shape=dn, minval=0., maxval=1.)
return _return(r)
- def randint(self, low, high=None, size=None, dtype=int, key=None):
+ def randint(self,
+ low,
+ high=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ dtype=int, key: Optional[Union[int, JAX_RAND_KEY]] = None):
dtype = get_int() if dtype is None else dtype
low = _as_jax_array(low)
high = _as_jax_array(high)
@@ -588,7 +594,11 @@ def randint(self, low, high=None, size=None, dtype=int, key=None):
minval=low, maxval=high, dtype=dtype)
return _return(r)
- def random_integers(self, low, high=None, size=None, key=None):
+ def random_integers(self,
+ low,
+ high=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
low = _as_jax_array(low)
high = _as_jax_array(high)
low = _check_py_seq(low)
@@ -606,29 +616,34 @@ def random_integers(self, low, high=None, size=None, key=None):
maxval=high)
return _return(r)
- def randn(self, *dn, key=None):
+ def randn(self, *dn, key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.normal(key, shape=dn)
return _return(r)
- def random(self, size=None, key=None):
+ def random(self,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.uniform(key, shape=_size2shape(size), minval=0., maxval=1.)
return _return(r)
- def random_sample(self, size=None, key=None):
+ def random_sample(self,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
r = self.random(size=size, key=key)
return _return(r)
- def ranf(self, size=None, key=None):
+ def ranf(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
r = self.random(size=size, key=key)
return _return(r)
- def sample(self, size=None, key=None):
+ def sample(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
r = self.random(size=size, key=key)
return _return(r)
- def choice(self, a, size=None, replace=True, p=None, key=None):
+ def choice(self, a, size: Optional[Union[int, Sequence[int]]] = None, replace=True, p=None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
a = _as_jax_array(a)
p = _as_jax_array(p)
a = _check_py_seq(a)
@@ -637,21 +652,23 @@ def choice(self, a, size=None, replace=True, p=None, key=None):
r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
return _return(r)
- def permutation(self, x, axis: int = 0, independent: bool = False, key=None):
+ def permutation(self, x, axis: int = 0, independent: bool = False, key: Optional[Union[int, JAX_RAND_KEY]] = None):
x = x.value if isinstance(x, Array) else x
x = _check_py_seq(x)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.permutation(key, x, axis=axis, independent=independent)
return _return(r)
- def shuffle(self, x, axis=0, key=None):
+ def shuffle(self, x, axis=0, key: Optional[Union[int, JAX_RAND_KEY]] = None):
if not isinstance(x, Array):
raise TypeError('This numpy operator needs in-place updating, therefore '
'inputs should be brainpy Array.')
key = self.split_key() if key is None else _formalize_key(key)
x.value = jr.permutation(key, x.value, axis=axis)
- def beta(self, a, b, size=None, key=None):
+ def beta(self, a, b,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
a = a.value if isinstance(a, Array) else a
b = b.value if isinstance(b, Array) else b
a = _check_py_seq(a)
@@ -662,7 +679,9 @@ def beta(self, a, b, size=None, key=None):
r = jr.beta(key, a=a, b=b, shape=_size2shape(size))
return _return(r)
- def exponential(self, scale=None, size=None, key=None):
+ def exponential(self, scale=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
scale = _as_jax_array(scale)
scale = _check_py_seq(scale)
if size is None:
@@ -673,7 +692,9 @@ def exponential(self, scale=None, size=None, key=None):
r = r / scale
return _return(r)
- def gamma(self, shape, scale=None, size=None, key=None):
+ def gamma(self, shape, scale=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
shape = _as_jax_array(shape)
scale = _as_jax_array(scale)
shape = _check_py_seq(shape)
@@ -686,7 +707,9 @@ def gamma(self, shape, scale=None, size=None, key=None):
r = r * scale
return _return(r)
- def gumbel(self, loc=None, scale=None, size=None, key=None):
+ def gumbel(self, loc=None, scale=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
loc = _as_jax_array(loc)
scale = _as_jax_array(scale)
loc = _check_py_seq(loc)
@@ -697,7 +720,9 @@ def gumbel(self, loc=None, scale=None, size=None, key=None):
r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size)))
return _return(r)
- def laplace(self, loc=None, scale=None, size=None, key=None):
+ def laplace(self, loc=None, scale=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
loc = _as_jax_array(loc)
scale = _as_jax_array(scale)
loc = _check_py_seq(loc)
@@ -708,7 +733,9 @@ def laplace(self, loc=None, scale=None, size=None, key=None):
r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size)))
return _return(r)
- def logistic(self, loc=None, scale=None, size=None, key=None):
+ def logistic(self, loc=None, scale=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
loc = _as_jax_array(loc)
scale = _as_jax_array(scale)
loc = _check_py_seq(loc)
@@ -719,7 +746,9 @@ def logistic(self, loc=None, scale=None, size=None, key=None):
r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size)))
return _return(r)
- def normal(self, loc=None, scale=None, size=None, key=None):
+ def normal(self, loc=None, scale=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
loc = _as_jax_array(loc)
scale = _as_jax_array(scale)
loc = _check_py_seq(loc)
@@ -730,7 +759,9 @@ def normal(self, loc=None, scale=None, size=None, key=None):
r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size)))
return _return(r)
- def pareto(self, a, size=None, key=None):
+ def pareto(self, a,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
a = _as_jax_array(a)
a = _check_py_seq(a)
if size is None:
@@ -739,7 +770,9 @@ def pareto(self, a, size=None, key=None):
r = jr.pareto(key, b=a, shape=_size2shape(size))
return _return(r)
- def poisson(self, lam=1.0, size=None, key=None):
+ def poisson(self, lam=1.0,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
lam = _check_py_seq(_as_jax_array(lam))
if size is None:
size = jnp.shape(lam)
@@ -747,17 +780,24 @@ def poisson(self, lam=1.0, size=None, key=None):
r = jr.poisson(key, lam=lam, shape=_size2shape(size))
return _return(r)
- def standard_cauchy(self, size=None, key=None):
+ def standard_cauchy(self,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.cauchy(key, shape=_size2shape(size))
return _return(r)
- def standard_exponential(self, size=None, key=None):
+ def standard_exponential(self,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.exponential(key, shape=_size2shape(size))
return _return(r)
- def standard_gamma(self, shape, size=None, key=None):
+ def standard_gamma(self,
+ shape,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
shape = _as_jax_array(shape)
shape = _check_py_seq(shape)
if size is None:
@@ -766,12 +806,16 @@ def standard_gamma(self, shape, size=None, key=None):
r = jr.gamma(key, a=shape, shape=_size2shape(size))
return _return(r)
- def standard_normal(self, size=None, key=None):
+ def standard_normal(self,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.normal(key, shape=_size2shape(size))
return _return(r)
- def standard_t(self, df, size=None, key=None):
+ def standard_t(self, df,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
df = _as_jax_array(df)
df = _check_py_seq(df)
if size is None:
@@ -780,7 +824,9 @@ def standard_t(self, df, size=None, key=None):
r = jr.t(key, df=df, shape=_size2shape(size))
return _return(r)
- def uniform(self, low=0.0, high=1.0, size=None, key=None):
+ def uniform(self, low=0.0, high=1.0,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
low = _as_jax_array(low)
high = _as_jax_array(high)
low = _check_py_seq(low)
@@ -791,39 +837,82 @@ def uniform(self, low=0.0, high=1.0, size=None, key=None):
r = jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high)
return _return(r)
- def truncated_normal(self, lower, upper, size=None, scale=None, key=None):
- lower = _as_jax_array(lower)
- lower = _check_py_seq(lower)
- upper = _as_jax_array(upper)
- upper = _check_py_seq(upper)
- scale = _as_jax_array(scale)
- scale = _check_py_seq(scale)
+ def __norm_cdf(self, x, sqrt2, dtype):
+ # Computes standard normal cumulative distribution function
+ return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
+
+ def truncated_normal(self,
+ lower,
+ upper,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ loc=0.,
+ scale=1.,
+ dtype=float,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ lower = _check_py_seq(_as_jax_array(lower))
+ upper = _check_py_seq(_as_jax_array(upper))
+ loc = _check_py_seq(_as_jax_array(loc))
+ scale = _check_py_seq(_as_jax_array(scale))
+
+ lower = lax.convert_element_type(lower, dtype)
+ upper = lax.convert_element_type(upper, dtype)
+ loc = lax.convert_element_type(loc, dtype)
+ scale = lax.convert_element_type(scale, dtype)
+
+ jit_error_checking_no_args(
+ jnp.any(jnp.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
+ ValueError("mean is more than 2 std from [lower, upper] in truncated_normal. "
+ "The distribution of values may be incorrect.")
+ )
+
if size is None:
size = lax.broadcast_shapes(jnp.shape(lower),
jnp.shape(upper),
+ jnp.shape(loc),
jnp.shape(scale))
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ sqrt2 = np.array(np.sqrt(2), dtype)
+ l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
+ u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
key = self.split_key() if key is None else _formalize_key(key)
- rands = jr.truncated_normal(key,
- lower=lower,
- upper=upper,
- shape=_size2shape(size))
- if scale is not None:
- rands = rands * scale
- return _return(rands)
+ out = jr.uniform(key, size, dtype,
+ minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
+ maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype)))
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ out = lax.erf_inv(out)
+
+ # Transform to proper mean, std
+ out = out * scale * sqrt2 + loc
+
+ # Clamp to ensure it's in the proper range
+ out = jnp.clip(out,
+ lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
+ lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
+ return _return(out)
def _check_p(self, p):
raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
- def bernoulli(self, p, size=None, key=None):
+ def bernoulli(self, p, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
p = _check_py_seq(_as_jax_array(p))
- jit_error(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
+    jit_error_checking(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p)
if size is None:
size = jnp.shape(p)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.bernoulli(key, p=p, shape=_size2shape(size))
return _return(r)
- def lognormal(self, mean=None, sigma=None, size=None, key=None):
+ def lognormal(self, mean=None, sigma=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
mean = _check_py_seq(_as_jax_array(mean))
sigma = _check_py_seq(_as_jax_array(sigma))
if size is None:
@@ -835,17 +924,19 @@ def lognormal(self, mean=None, sigma=None, size=None, key=None):
samples = jnp.exp(samples)
return _return(samples)
- def binomial(self, n, p, size=None, key=None):
+ def binomial(self, n, p, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
n = _check_py_seq(n.value if isinstance(n, Array) else n)
p = _check_py_seq(p.value if isinstance(p, Array) else p)
- jit_error(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
+    jit_error_checking(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p)
if size is None:
size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
key = self.split_key() if key is None else _formalize_key(key)
r = _binomial(key, p, n, shape=_size2shape(size))
return _return(r)
- def chisquare(self, df, size=None, key=None):
+ def chisquare(self, df, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
df = _check_py_seq(_as_jax_array(df))
key = self.split_key() if key is None else _formalize_key(key)
if size is None:
@@ -859,13 +950,15 @@ def chisquare(self, df, size=None, key=None):
dist = dist.sum(axis=0)
return _return(dist)
- def dirichlet(self, alpha, size=None, key=None):
+ def dirichlet(self, alpha, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
alpha = _check_py_seq(_as_jax_array(alpha))
r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size))
return _return(r)
- def geometric(self, p, size=None, key=None):
+ def geometric(self, p, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
p = _as_jax_array(p)
p = _check_py_seq(p)
if size is None:
@@ -878,11 +971,12 @@ def geometric(self, p, size=None, key=None):
def _check_p2(self, p):
raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
- def multinomial(self, n, pvals, size=None, key=None):
+ def multinomial(self, n, pvals, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
n = _check_py_seq(_as_jax_array(n))
pvals = _check_py_seq(_as_jax_array(pvals))
- jit_error(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
+ jit_error_checking(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
if isinstance(n, jax.core.Tracer):
raise ValueError("The total count parameter `n` should not be a jax abstract array.")
size = _size2shape(size)
@@ -891,7 +985,8 @@ def multinomial(self, n, pvals, size=None, key=None):
r = _multinomial(key, pvals, n, n_max, batch_shape + size)
return _return(r)
- def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky', key=None):
+ def multivariate_normal(self, mean, cov, size: Optional[Union[int, Sequence[int]]] = None, method: str = 'cholesky',
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
if method not in {'svd', 'eigh', 'cholesky'}:
raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
mean = _check_py_seq(_as_jax_array(mean))
@@ -924,7 +1019,8 @@ def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky', ke
r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
return _return(r)
- def rayleigh(self, scale=1.0, size=None, key=None):
+ def rayleigh(self, scale=1.0, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
scale = _check_py_seq(_as_jax_array(scale))
if size is None:
size = jnp.shape(scale)
@@ -933,13 +1029,15 @@ def rayleigh(self, scale=1.0, size=None, key=None):
r = x * scale
return _return(r)
- def triangular(self, size=None, key=None):
+ def triangular(self, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
r = 2 * bernoulli_samples - 1
return _return(r)
- def vonmises(self, mu, kappa, size=None, key=None):
+ def vonmises(self, mu, kappa, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
mu = _check_py_seq(_as_jax_array(mu))
kappa = _check_py_seq(_as_jax_array(kappa))
@@ -951,7 +1049,8 @@ def vonmises(self, mu, kappa, size=None, key=None):
samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
return _return(samples)
- def weibull(self, a, size=None, key=None):
+ def weibull(self, a, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
a = _check_py_seq(_as_jax_array(a))
if size is None:
@@ -964,7 +1063,8 @@ def weibull(self, a, size=None, key=None):
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
return _return(r)
- def weibull_min(self, a, scale=None, size=None, key=None):
+ def weibull_min(self, a, scale=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Sample from a Weibull minimum distribution.
Parameters
@@ -996,14 +1096,15 @@ def weibull_min(self, a, scale=None, size=None, key=None):
r /= scale
return _return(r)
- def maxwell(self, size=None, key=None):
+ def maxwell(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
shape = core.canonicalize_shape(_size2shape(size)) + (3,)
norm_rvs = jr.normal(key=key, shape=shape)
r = jnp.linalg.norm(norm_rvs, axis=-1)
return _return(r)
- def negative_binomial(self, n, p, size=None, key=None):
+ def negative_binomial(self, n, p, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
n = _check_py_seq(_as_jax_array(n))
p = _check_py_seq(_as_jax_array(p))
if size is None:
@@ -1018,7 +1119,8 @@ def negative_binomial(self, n, p, size=None, key=None):
r = self.poisson(lam=rate, key=keys[1])
return _return(r)
- def wald(self, mean, scale, size=None, key=None):
+ def wald(self, mean, scale, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
mean = _check_py_seq(_as_jax_array(mean))
scale = _check_py_seq(_as_jax_array(scale))
@@ -1058,7 +1160,7 @@ def wald(self, mean, scale, size=None, key=None):
jnp.square(mean) / sampled)
return _return(res)
- def t(self, df, size=None, key=None):
+ def t(self, df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
df = _check_py_seq(_as_jax_array(df))
if size is None:
size = np.shape(df)
@@ -1076,7 +1178,8 @@ def t(self, df, size=None, key=None):
r = n * jnp.sqrt(half_df / g)
return _return(r)
- def orthogonal(self, n: int, size=None, key=None):
+ def orthogonal(self, n: int, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
size = _size2shape(size)
_check_shape("orthogonal", size)
@@ -1087,7 +1190,8 @@ def orthogonal(self, n: int, size=None, key=None):
r = q * jnp.expand_dims(d / abs(d), -2)
return _return(r)
- def noncentral_chisquare(self, df, nonc, size=None, key=None):
+ def noncentral_chisquare(self, df, nonc, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
df = _check_py_seq(_as_jax_array(df))
nonc = _check_py_seq(_as_jax_array(nonc))
if size is None:
@@ -1105,7 +1209,8 @@ def noncentral_chisquare(self, df, nonc, size=None, key=None):
r = jnp.where(cond, chi2 + n * n, chi2)
return _return(r)
- def loggamma(self, a, size=None, key=None):
+ def loggamma(self, a, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
a = _check_py_seq(_as_jax_array(a))
if size is None:
@@ -1113,7 +1218,8 @@ def loggamma(self, a, size=None, key=None):
r = jr.loggamma(key, a, shape=_size2shape(size))
return _return(r)
- def categorical(self, logits, axis: int = -1, size=None, key=None):
+ def categorical(self, logits, axis: int = -1, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
logits = _check_py_seq(_as_jax_array(logits))
if size is None:
@@ -1122,7 +1228,7 @@ def categorical(self, logits, axis: int = -1, size=None, key=None):
r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
return _return(r)
- def zipf(self, a, size=None, key=None):
+ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
a = _check_py_seq(_as_jax_array(a))
if size is None:
size = jnp.shape(a)
@@ -1131,7 +1237,7 @@ def zipf(self, a, size=None, key=None):
result_shape=jax.ShapeDtypeStruct(size, jnp.int_))
return _return(r)
- def power(self, a, size=None, key=None):
+ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
a = _check_py_seq(_as_jax_array(a))
if size is None:
size = jnp.shape(a)
@@ -1140,7 +1246,8 @@ def power(self, a, size=None, key=None):
a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_))
return _return(r)
- def f(self, dfnum, dfden, size=None, key=None):
+ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
dfnum = _as_jax_array(dfnum)
dfden = _as_jax_array(dfden)
dfnum = _check_py_seq(dfnum)
@@ -1156,7 +1263,8 @@ def f(self, dfnum, dfden, size=None, key=None):
result_shape=jax.ShapeDtypeStruct(size, jnp.float_))
return _return(r)
- def hypergeometric(self, ngood, nbad, nsample, size=None, key=None):
+ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
ngood = _check_py_seq(_as_jax_array(ngood))
nbad = _check_py_seq(_as_jax_array(nbad))
nsample = _check_py_seq(_as_jax_array(nsample))
@@ -1174,7 +1282,8 @@ def hypergeometric(self, ngood, nbad, nsample, size=None, key=None):
d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))
return _return(r)
- def logseries(self, p, size=None, key=None):
+ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
p = _check_py_seq(_as_jax_array(p))
if size is None:
size = jnp.shape(p)
@@ -1183,7 +1292,8 @@ def logseries(self, p, size=None, key=None):
p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))
return _return(r)
- def noncentral_f(self, dfnum, dfden, nonc, size=None, key=None):
+ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
dfnum = _check_py_seq(_as_jax_array(dfnum))
dfden = _check_py_seq(_as_jax_array(dfden))
nonc = _check_py_seq(_as_jax_array(nonc))
@@ -1203,7 +1313,7 @@ def noncentral_f(self, dfnum, dfden, nonc, size=None, key=None):
# PyTorch compatibility #
# --------------------- #
- def rand_like(self, input, *, dtype=None, key=None):
+ def rand_like(self, input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Returns a tensor with the same size as input that is filled with random
numbers from a uniform distribution on the interval ``[0, 1)``.
@@ -1217,7 +1327,7 @@ def rand_like(self, input, *, dtype=None, key=None):
"""
return self.random(shape(input), key=key).astype(dtype)
- def randn_like(self, input, *, dtype=None, key=None):
+ def randn_like(self, input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Returns a tensor with the same size as ``input`` that is filled with
random numbers from a normal distribution with mean 0 and variance 1.
@@ -1231,7 +1341,7 @@ def randn_like(self, input, *, dtype=None, key=None):
"""
return self.randn(*shape(input), key=key).astype(dtype)
- def randint_like(self, input, low=0, high=None, *, dtype=None, key=None):
+ def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
if high is None:
high = max(input)
return self.randint(low, high=high, size=shape(input), dtype=dtype, key=key)
@@ -1248,6 +1358,9 @@ def randint_like(self, input, low=0, high=None, *, dtype=None, key=None):
def split_key():
+ """Create a new seed from the current seed.
+
+  This function is provided for consistency with JAX's random paradigm."""
return DEFAULT.split_key()
@@ -1282,7 +1395,7 @@ def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
return RandomState(seed_or_key)
-def default_rng(seed_or_key=None, clone=True) -> RandomState:
+def default_rng(seed_or_key=None, clone: bool = True) -> RandomState:
if seed_or_key is None:
return DEFAULT.clone() if clone else DEFAULT
else:
@@ -1304,7 +1417,7 @@ def seed(seed: int = None):
DEFAULT.seed(seed)
-def rand(*dn, key=None):
+def rand(*dn, key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""Random values in a given shape.
.. note::
@@ -1342,7 +1455,8 @@ def rand(*dn, key=None):
return DEFAULT.rand(*dn, key=key)
-def randint(low, high=None, size=None, dtype=int, key=None):
+def randint(low, high=None, size: Optional[Union[int, Sequence[int]]] = None, dtype=int,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""Return random integers from `low` (inclusive) to `high` (exclusive).
Return random integers from the "discrete uniform" distribution of
@@ -1414,7 +1528,10 @@ def randint(low, high=None, size=None, dtype=int, key=None):
return DEFAULT.randint(low, high=high, size=size, dtype=dtype, key=key)
-def random_integers(low, high=None, size=None, key=None):
+def random_integers(low,
+ high=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""
Random integers of type `np.int_` between `low` and `high`, inclusive.
@@ -1492,7 +1609,7 @@ def random_integers(low, high=None, size=None, key=None):
return DEFAULT.random_integers(low, high=high, size=size, key=key)
-def randn(*dn, key=None):
+def randn(*dn, key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""
Return a sample (or samples) from the "standard normal" distribution.
@@ -1529,7 +1646,6 @@ def randn(*dn, key=None):
--------
standard_normal : Similar, but takes a tuple as its argument.
normal : Also accepts mu and sigma arguments.
- random.Generator.standard_normal: which should be used for new code.
Notes
-----
@@ -1553,15 +1669,15 @@ def randn(*dn, key=None):
return DEFAULT.randn(*dn, key=key)
-def random(size=None, key=None):
- """
+def random(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
Return random floats in the half-open interval [0.0, 1.0). Alias for
`random_sample` to ease forward-porting to the new random API.
"""
return DEFAULT.random(size, key=key)
-def random_sample(size=None, key=None):
+def random_sample(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""
Return random floats in the half-open interval [0.0, 1.0).
@@ -1612,23 +1728,24 @@ def random_sample(size=None, key=None):
return DEFAULT.random_sample(size, key=key)
-def ranf(size=None, key=None):
- """
+def ranf(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
This is an alias of `random_sample`. See `random_sample` for the complete
- documentation.
+ documentation.
"""
return DEFAULT.ranf(size, key=key)
-def sample(size=None, key=None):
+def sample(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""
This is an alias of `random_sample`. See `random_sample` for the complete
- documentation.
+ documentation.
"""
return DEFAULT.sample(size, key=key)
-def choice(a, size=None, replace=True, p=None, key=None):
+def choice(a, size: Optional[Union[int, Sequence[int]]] = None, replace=True, p=None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""
Generates a random sample from a given 1-D array
@@ -1716,7 +1833,10 @@ def choice(a, size=None, replace=True, p=None, key=None):
return DEFAULT.choice(a=a, size=size, replace=replace, p=p, key=key)
-def permutation(x, axis: int = 0, independent: bool = False, key=None):
+def permutation(x,
+ axis: int = 0,
+ independent: bool = False,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""
Randomly permute a sequence, or return a permuted range.
@@ -1735,10 +1855,6 @@ def permutation(x, axis: int = 0, independent: bool = False, key=None):
out : ndarray
Permuted sequence or array range.
- See Also
- --------
- random.Generator.permutation: which should be used for new code.
-
Examples
--------
>>> import brainpy.math as bm
@@ -1757,7 +1873,7 @@ def permutation(x, axis: int = 0, independent: bool = False, key=None):
return DEFAULT.permutation(x, axis=axis, independent=independent, key=key)
-def shuffle(x, axis=0, key=None):
+def shuffle(x, axis=0, key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""
Modify a sequence in-place by shuffling its contents.
@@ -1774,10 +1890,6 @@ def shuffle(x, axis=0, key=None):
-------
None
- See Also
- --------
- random.Generator.shuffle: which should be used for new code.
-
Examples
--------
>>> import brainpy.math as bm
@@ -1798,7 +1910,7 @@ def shuffle(x, axis=0, key=None):
DEFAULT.shuffle(x, axis, key=key)
-def beta(a, b, size=None, key=None):
+def beta(a, b, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""
Draw samples from a Beta distribution.
@@ -1832,249 +1944,451 @@ def beta(a, b, size=None, key=None):
-------
out : ndarray or scalar
Drawn samples from the parameterized beta distribution.
-
- See Also
- --------
- random.Generator.beta: which should be used for new code.
"""
return DEFAULT.beta(a, b, size=size, key=key)
-# @wraps(np.random.exponential)
-def exponential(scale=None, size=None, key=None):
- return DEFAULT.exponential(scale, size, key=key)
+def exponential(scale=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from an exponential distribution.
+ Its probability density function is
-# @wraps(np.random.gamma)
-def gamma(shape, scale=None, size=None, key=None):
- return DEFAULT.gamma(shape, scale, size=size, key=key)
+ .. math:: f(x; \frac{1}{\beta}) = \frac{1}{\beta} \exp(-\frac{x}{\beta}),
+ for ``x > 0`` and 0 elsewhere. :math:`\beta` is the scale parameter,
+ which is the inverse of the rate parameter :math:`\lambda = 1/\beta`.
+ The rate parameter is an alternative, widely used parameterization
+ of the exponential distribution [3]_.
-# @wraps(np.random.gumbel)
-def gumbel(loc=None, scale=None, size=None, key=None):
- return DEFAULT.gumbel(loc, scale, size=size, key=key)
+ The exponential distribution is a continuous analogue of the
+ geometric distribution. It describes many common situations, such as
+ the size of raindrops measured over many rainstorms [1]_, or the time
+ between page requests to Wikipedia [2]_.
+ Parameters
+ ----------
+ scale : float or array_like of floats
+ The scale parameter, :math:`\beta = 1/\lambda`. Must be
+ non-negative.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``scale`` is a scalar. Otherwise,
+ ``np.array(scale).size`` samples are drawn.
-# @wraps(np.random.laplace)
-def laplace(loc=None, scale=None, size=None, key=None):
- return DEFAULT.laplace(loc, scale, size, key=key)
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized exponential distribution.
+ References
+ ----------
+ .. [1] Peyton Z. Peebles Jr., "Probability, Random Variables and
+ Random Signal Principles", 4th ed, 2001, p. 57.
+ .. [2] Wikipedia, "Poisson process",
+ https://en.wikipedia.org/wiki/Poisson_process
+ .. [3] Wikipedia, "Exponential distribution",
+ https://en.wikipedia.org/wiki/Exponential_distribution
+ """
+ return DEFAULT.exponential(scale, size, key=key)
-# @wraps(np.random.logistic)
-def logistic(loc=None, scale=None, size=None, key=None):
- return DEFAULT.logistic(loc, scale, size, key=key)
+def gamma(shape, scale=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a Gamma distribution.
-# @wraps(np.random.normal)
-def normal(loc=None, scale=None, size=None, key=None):
- return DEFAULT.normal(loc, scale, size, key=key)
+ Samples are drawn from a Gamma distribution with specified parameters,
+ `shape` (sometimes designated "k") and `scale` (sometimes designated
+ "theta"), where both parameters are > 0.
+ Parameters
+ ----------
+ shape : float or array_like of floats
+ The shape of the gamma distribution. Must be non-negative.
+ scale : float or array_like of floats, optional
+ The scale of the gamma distribution. Must be non-negative.
+ Default is equal to 1.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``shape`` and ``scale`` are both scalars.
+ Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn.
-# @wraps(np.random.pareto)
-def pareto(a, size=None, key=None):
- return DEFAULT.pareto(a, size, key=key)
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized gamma distribution.
-# @wraps(np.random.poisson)
-def poisson(lam=1.0, size=None, key=None):
- return DEFAULT.poisson(lam, size, key=key)
+ Notes
+ -----
+ The probability density for the Gamma distribution is
+ .. math:: p(x) = x^{k-1}\frac{e^{-x/\theta}}{\theta^k\Gamma(k)},
-# @wraps(np.random.standard_cauchy)
-def standard_cauchy(size=None, key=None):
- return DEFAULT.standard_cauchy(size, key=key)
+ where :math:`k` is the shape and :math:`\theta` the scale,
+ and :math:`\Gamma` is the Gamma function.
+ The Gamma distribution is often used to model the times to failure of
+ electronic components, and arises naturally in processes for which the
+ waiting times between Poisson distributed events are relevant.
-# @wraps(np.random.standard_exponential)
-def standard_exponential(size=None, key=None):
- return DEFAULT.standard_exponential(size, key=key)
+ References
+ ----------
+ .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A
+ Wolfram Web Resource.
+ http://mathworld.wolfram.com/GammaDistribution.html
+ .. [2] Wikipedia, "Gamma distribution",
+ https://en.wikipedia.org/wiki/Gamma_distribution
+ """
+ return DEFAULT.gamma(shape, scale, size=size, key=key)
-# @wraps(np.random.standard_gamma)
-def standard_gamma(shape, size=None, key=None):
- return DEFAULT.standard_gamma(shape, size, key=key)
+def gumbel(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a Gumbel distribution.
-# @wraps(np.random.standard_normal)
-def standard_normal(size=None, key=None):
- return DEFAULT.standard_normal(size, key=key)
+ Draw samples from a Gumbel distribution with specified location and
+ scale. For more information on the Gumbel distribution, see
+ Notes and References below.
+ Parameters
+ ----------
+ loc : float or array_like of floats, optional
+ The location of the mode of the distribution. Default is 0.
+ scale : float or array_like of floats, optional
+ The scale parameter of the distribution. Default is 1. Must be non-
+ negative.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``loc`` and ``scale`` are both scalars.
+ Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn.
-# @wraps(np.random.standard_t)
-def standard_t(df, size=None, key=None):
- return DEFAULT.standard_t(df, size, key=key)
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized Gumbel distribution.
+ Notes
+ -----
+ The Gumbel (or Smallest Extreme Value (SEV) or the Smallest Extreme
+ Value Type I) distribution is one of a class of Generalized Extreme
+ Value (GEV) distributions used in modeling extreme value problems.
+ The Gumbel is a special case of the Extreme Value Type I distribution
+ for maximums from distributions with "exponential-like" tails.
-# @wraps(np.random.uniform)
-def uniform(low=0.0, high=1.0, size=None, key=None):
- return DEFAULT.uniform(low, high, size, key=key)
+ The probability density for the Gumbel distribution is
+ .. math:: p(x) = \frac{e^{-(x - \mu)/ \beta}}{\beta} e^{ -e^{-(x - \mu)/
+ \beta}},
-def truncated_normal(lower, upper, size=None, scale=None, key=None):
- """Sample truncated standard normal random values with given shape and dtype.
+ where :math:`\mu` is the mode, a location parameter, and
+ :math:`\beta` is the scale parameter.
- Parameters
- ----------
- lower : float, ndarray
- A float or array of floats representing the lower bound for
- truncation. Must be broadcast-compatible with ``upper``.
- upper : float, ndarray
- A float or array of floats representing the upper bound for
- truncation. Must be broadcast-compatible with ``lower``.
- size : optional, list of int, tuple of int
- A tuple of nonnegative integers specifying the result
- shape. Must be broadcast-compatible with ``lower`` and ``upper``. The
- default (None) produces a result shape by broadcasting ``lower`` and
- ``upper``.
- scale : float, ndarray
- Standard deviation (spread or "width") of the distribution. Must be
- non-negative.
+ The Gumbel (named for German mathematician Emil Julius Gumbel) was used
+ very early in the hydrology literature, for modeling the occurrence of
+ flood events. It is also used for modeling maximum wind speed and
+ rainfall rates. It is a "fat-tailed" distribution - the probability of
+ an event in the tail of the distribution is larger than if one used a
+ Gaussian, hence the surprisingly frequent occurrence of 100-year
+ floods. Floods were initially modeled as a Gaussian process, which
+ underestimated the frequency of extreme events.
- Returns
- -------
- out : Array
- A random array with the specified dtype and shape given by ``shape`` if
- ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
- Returns values in the open interval ``(lower, upper)``.
+ It is one of a class of extreme value distributions, the Generalized
+ Extreme Value (GEV) distributions, which also includes the Weibull and
+ Frechet.
+
+ The function has a mean of :math:`\mu + 0.57721\beta` and a variance
+ of :math:`\frac{\pi^2}{6}\beta^2`.
+
+ References
+ ----------
+ .. [1] Gumbel, E. J., "Statistics of Extremes,"
+ New York: Columbia University Press, 1958.
+ .. [2] Reiss, R.-D. and Thomas, M., "Statistical Analysis of Extreme
+ Values from Insurance, Finance, Hydrology and Other Fields,"
+ Basel: Birkhauser Verlag, 2001.
"""
- return DEFAULT.truncated_normal(lower, upper, size, scale, key=key)
+ return DEFAULT.gumbel(loc, scale, size=size, key=key)
+
+def laplace(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from the Laplace or double exponential distribution with
+ specified location (or mean) and scale (decay).
-def bernoulli(p=0.5, size=None, key=None):
- """Sample Bernoulli random values with given shape and mean.
+ The Laplace distribution is similar to the Gaussian/normal distribution,
+ but is sharper at the peak and has fatter tails. It represents the
+ difference between two independent, identically distributed exponential
+ random variables.
Parameters
----------
- p: float, array_like, optional
- A float or array of floats for the mean of the random
- variables. Must be broadcast-compatible with ``shape`` and the values
- should be within [0, 1]. Default 0.5.
- size: optional, tuple of int, int
- A tuple of nonnegative integers representing the result
- shape. Must be broadcast-compatible with ``p.shape``. The default (None)
- produces a result shape equal to ``p.shape``.
+ loc : float or array_like of floats, optional
+ The position, :math:`\mu`, of the distribution peak. Default is 0.
+ scale : float or array_like of floats, optional
+ :math:`\lambda`, the exponential decay. Default is 1. Must be non-
+ negative.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``loc`` and ``scale`` are both scalars.
+ Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn.
Returns
-------
- out: array_like
- A random array with boolean dtype and shape given by ``shape`` if ``shape``
- is not None, or else ``p.shape``.
- """
- return DEFAULT.bernoulli(p, size, key=key)
+ out : ndarray or scalar
+ Drawn samples from the parameterized Laplace distribution.
+ Notes
+ -----
+ It has the probability density function
-# @wraps(np.random.lognormal)
-def lognormal(mean=None, sigma=None, size=None, key=None):
- return DEFAULT.lognormal(mean, sigma, size, key=key)
+ .. math:: f(x; \mu, \lambda) = \frac{1}{2\lambda}
+ \exp\left(-\frac{|x - \mu|}{\lambda}\right).
+ The first law of Laplace, from 1774, states that the frequency
+ of an error can be expressed as an exponential function of the
+ absolute magnitude of the error, which leads to the Laplace
+ distribution. For many problems in economics and health
+ sciences, this distribution seems to model the data better
+ than the standard Gaussian distribution.
-# @wraps(np.random.binomial)
-def binomial(n, p, size=None, key=None):
- return DEFAULT.binomial(n, p, size, key=key)
+ References
+ ----------
+ .. [1] Abramowitz, M. and Stegun, I. A. (Eds.). "Handbook of
+ Mathematical Functions with Formulas, Graphs, and Mathematical
+ Tables, 9th printing," New York: Dover, 1972.
+ .. [2] Kotz, Samuel, et. al. "The Laplace Distribution and
+ Generalizations, " Birkhauser, 2001.
+ .. [3] Weisstein, Eric W. "Laplace Distribution."
+ From MathWorld--A Wolfram Web Resource.
+ http://mathworld.wolfram.com/LaplaceDistribution.html
+ .. [4] Wikipedia, "Laplace distribution",
+ https://en.wikipedia.org/wiki/Laplace_distribution
+ Examples
+ --------
+ Draw samples from the distribution
-# @wraps(np.random.chisquare)
-def chisquare(df, size=None, key=None):
- return DEFAULT.chisquare(df, size, key=key)
+ >>> loc, scale = 0., 1.
+ >>> s = bm.random.laplace(loc, scale, 1000)
+ Display the histogram of the samples, along with
+ the probability density function:
-# @wraps(np.random.dirichlet)
-def dirichlet(alpha, size=None, key=None):
- return DEFAULT.dirichlet(alpha, size, key=key)
+ >>> import matplotlib.pyplot as plt
+ >>> count, bins, ignored = plt.hist(s, 30, density=True)
+ >>> x = np.arange(-8., 8., .01)
+ >>> pdf = np.exp(-abs(x-loc)/scale)/(2.*scale)
+ >>> plt.plot(x, pdf)
+ Plot Gaussian for comparison:
-# @wraps(np.random.geometric)
-def geometric(p, size=None, key=None):
- return DEFAULT.geometric(p, size, key=key)
+ >>> g = (1/(scale * np.sqrt(2 * np.pi)) *
+ ... np.exp(-(x - loc)**2 / (2 * scale**2)))
+ >>> plt.plot(x,g)
+ """
+ return DEFAULT.laplace(loc, scale, size, key=key)
-# @wraps(np.random.f)
-def f(dfnum, dfden, size=None, key=None):
- return DEFAULT.f(dfnum, dfden, size, key=key)
+def logistic(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a logistic distribution.
+ Samples are drawn from a logistic distribution with specified
+ parameters, loc (location or mean, also median), and scale (>0).
-# @wraps(np.random.hypergeometric)
-def hypergeometric(ngood, nbad, nsample, size=None, key=None):
- return DEFAULT.hypergeometric(ngood, nbad, nsample, size, key=key)
+ Parameters
+ ----------
+ loc : float or array_like of floats, optional
+ Parameter of the distribution. Default is 0.
+ scale : float or array_like of floats, optional
+ Parameter of the distribution. Must be non-negative.
+ Default is 1.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``loc`` and ``scale`` are both scalars.
+ Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn.
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized logistic distribution.
-# @wraps(np.random.logseries)
-def logseries(p, size=None, key=None):
- return DEFAULT.logseries(p, size, key=key)
+ Notes
+ -----
+ The probability density for the Logistic distribution is
+  .. math:: P(x) = \frac{e^{-(x-\mu)/s}}{s(1+e^{-(x-\mu)/s})^2},
-# @wraps(np.random.multinomial)
-def multinomial(n, pvals, size=None, key=None):
- return DEFAULT.multinomial(n, pvals, size, key=key)
+ where :math:`\mu` = location and :math:`s` = scale.
+ The Logistic distribution is used in Extreme Value problems where it
+ can act as a mixture of Gumbel distributions, in Epidemiology, and by
+ the World Chess Federation (FIDE) where it is used in the Elo ranking
+ system, assuming the performance of each player is a logistically
+ distributed random variable.
-# @wraps(np.random.multivariate_normal)
-def multivariate_normal(mean, cov, size=None, method: str = 'cholesky', key=None):
- return DEFAULT.multivariate_normal(mean, cov, size, method, key=key)
+ References
+ ----------
+ .. [1] Reiss, R.-D. and Thomas M. (2001), "Statistical Analysis of
+ Extreme Values, from Insurance, Finance, Hydrology and Other
+ Fields," Birkhauser Verlag, Basel, pp 132-133.
+ .. [2] Weisstein, Eric W. "Logistic Distribution." From
+ MathWorld--A Wolfram Web Resource.
+ http://mathworld.wolfram.com/LogisticDistribution.html
+ .. [3] Wikipedia, "Logistic-distribution",
+ https://en.wikipedia.org/wiki/Logistic_distribution
+ Examples
+ --------
+ Draw samples from the distribution:
-# @wraps(np.random.negative_binomial)
-def negative_binomial(n, p, size=None, key=None):
- return DEFAULT.negative_binomial(n, p, size, key=key)
+ >>> loc, scale = 10, 1
+ >>> s = bm.random.logistic(loc, scale, 10000)
+ >>> import matplotlib.pyplot as plt
+ >>> count, bins, ignored = plt.hist(s, bins=50)
+ # plot against distribution
-# @wraps(np.random.noncentral_chisquare)
-def noncentral_chisquare(df, nonc, size=None, key=None):
- return DEFAULT.noncentral_chisquare(df, nonc, size, key=key)
+ >>> def logist(x, loc, scale):
+ ... return np.exp((loc-x)/scale)/(scale*(1+np.exp((loc-x)/scale))**2)
+ >>> lgst_val = logist(bins, loc, scale)
+ >>> plt.plot(bins, lgst_val * count.max() / lgst_val.max())
+ >>> plt.show()
+ """
+ return DEFAULT.logistic(loc, scale, size, key=key)
-# @wraps(np.random.noncentral_f)
-def noncentral_f(dfnum, dfden, nonc, size=None, key=None):
- return DEFAULT.noncentral_f(dfnum, dfden, nonc, size, key=key)
+def normal(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw random samples from a normal (Gaussian) distribution.
+ The probability density function of the normal distribution, first
+ derived by De Moivre and 200 years later by both Gauss and Laplace
+ independently [2]_, is often called the bell curve because of
+ its characteristic shape (see the example below).
-# @wraps(np.random.power)
-def power(a, size=None, key=None):
- return DEFAULT.power(a, size, key=key)
+ The normal distribution occurs often in nature. For example, it
+ describes the commonly occurring distribution of samples influenced
+ by a large number of tiny, random disturbances, each with its own
+ unique distribution [2]_.
+ Parameters
+ ----------
+ loc : float or array_like of floats
+ Mean ("centre") of the distribution.
+ scale : float or array_like of floats
+ Standard deviation (spread or "width") of the distribution. Must be
+ non-negative.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``loc`` and ``scale`` are both scalars.
+ Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn.
-# @wraps(np.random.rayleigh)
-def rayleigh(scale=1.0, size=None, key=None):
- return DEFAULT.rayleigh(scale, size, key=key)
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized normal distribution.
+ Notes
+ -----
+ The probability density for the Gaussian distribution is
-# @wraps(np.random.triangular)
-def triangular(size=None, key=None):
- return DEFAULT.triangular(size, key=key)
+ .. math:: p(x) = \frac{1}{\sqrt{ 2 \pi \sigma^2 }}
+ e^{ - \frac{ (x - \mu)^2 } {2 \sigma^2} },
+ where :math:`\mu` is the mean and :math:`\sigma` the standard
+ deviation. The square of the standard deviation, :math:`\sigma^2`,
+ is called the variance.
-# @wraps(np.random.vonmises)
-def vonmises(mu, kappa, size=None, key=None):
- return DEFAULT.vonmises(mu, kappa, size, key=key)
+ The function has its peak at the mean, and its "spread" increases with
+ the standard deviation (the function reaches 0.607 times its maximum at
+ :math:`\mu + \sigma` and :math:`\mu - \sigma` [2]_). This implies that
+ normal is more likely to return samples lying close to the mean, rather
+ than those far away.
+ References
+ ----------
+ .. [1] Wikipedia, "Normal distribution",
+ https://en.wikipedia.org/wiki/Normal_distribution
+ .. [2] P. R. Peebles Jr., "Central Limit Theorem" in "Probability,
+ Random Variables and Random Signal Principles", 4th ed., 2001,
+ pp. 51, 125.
-# @wraps(np.random.wald)
-def wald(mean, scale, size=None, key=None):
- return DEFAULT.wald(mean, scale, size, key=key)
+ Examples
+ --------
+ Draw samples from the distribution:
+ >>> mu, sigma = 0, 0.1 # mean and standard deviation
+ >>> s = bm.random.normal(mu, sigma, 1000)
-def weibull(a, size=None, key=None):
- r"""
- Draw samples from a Weibull distribution.
-
- Draw samples from a 1-parameter Weibull distribution with the given
- shape parameter `a`.
+ Verify the mean and the variance:
- .. math:: X = (-ln(U))^{1/a}
+ >>> abs(mu - np.mean(s))
+ 0.0 # may vary
- Here, U is drawn from the uniform distribution over (0,1].
+ >>> abs(sigma - np.std(s, ddof=1))
+ 0.1 # may vary
- The more common 2-parameter Weibull, including a scale parameter
- :math:`\lambda` is just :math:`X = \lambda(-ln(U))^{1/a}`.
+ Display the histogram of the samples, along with
+ the probability density function:
- .. note::
- New code should use the ``weibull`` method of a ``default_rng()``
- instance instead; please see the :ref:`random-quick-start`.
+ >>> import matplotlib.pyplot as plt
+ >>> count, bins, ignored = plt.hist(s, 30, density=True)
+ >>> plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) *
+ ... np.exp( - (bins - mu)**2 / (2 * sigma**2) ),
+ ... linewidth=2, color='r')
+ >>> plt.show()
+
+ Two-by-four array of samples from the normal distribution with
+ mean 3 and standard deviation 2.5:
+
+ >>> bm.random.normal(3, 2.5, size=(2, 4))
+ array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
+ [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
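+
+ An explicit ``key`` can also be supplied to make the draw reproducible,
+ following JAX's explicit-key paradigm (the integer seed below is
+ illustrative only):
+
+ >>> s = bm.random.normal(3, 2.5, size=(2, 4), key=42)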
+ """
+ return DEFAULT.normal(loc, scale, size, key=key)
+
+
+def pareto(a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a Pareto II or Lomax distribution with
+ specified shape.
+
+ The Lomax or Pareto II distribution is a shifted Pareto
+ distribution. The classical Pareto distribution can be
+ obtained from the Lomax distribution by adding 1 and
+ multiplying by the scale parameter ``m`` (see Notes). The
+ smallest value of the Lomax distribution is zero while for the
+ classical Pareto distribution it is ``mu``, where the standard
+ Pareto distribution has location ``mu = 1``. Lomax can also
+ be considered as a simplified version of the Generalized
+ Pareto distribution (available in SciPy), with the scale set
+ to one and the location set to zero.
+
+ The Pareto distribution must be greater than zero, and is
+ unbounded above. It is also known as the "80-20 rule". In
+ this distribution, 80 percent of the weights are in the lowest
+ 20 percent of the range, while the other 20 percent fill the
+ remaining 80 percent of the range.
Parameters
----------
a : float or array_like of floats
- Shape parameter of the distribution. Must be nonnegative.
+ Shape of the distribution. Must be positive.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
@@ -2084,63 +2398,2112 @@ def weibull(a, size=None, key=None):
Returns
-------
out : ndarray or scalar
- Drawn samples from the parameterized Weibull distribution.
+ Drawn samples from the parameterized Pareto distribution.
See Also
--------
- scipy.stats.weibull_max
- scipy.stats.weibull_min
- scipy.stats.genextreme
- gumbel
- random.Generator.weibull: which should be used for new code.
+ scipy.stats.lomax : probability density function, distribution or
+ cumulative density function, etc.
+ scipy.stats.genpareto : probability density function, distribution or
+ cumulative density function, etc.
Notes
-----
- The Weibull (or Type III asymptotic extreme value distribution
- for smallest values, SEV Type III, or Rosin-Rammler
- distribution) is one of a class of Generalized Extreme Value
- (GEV) distributions used in modeling extreme value problems.
- This class includes the Gumbel and Frechet distributions.
-
- The probability density for the Weibull distribution is
+ The probability density for the Pareto distribution is
- .. math:: p(x) = \frac{a}
- {\lambda}(\frac{x}{\lambda})^{a-1}e^{-(x/\lambda)^a},
-
- where :math:`a` is the shape and :math:`\lambda` the scale.
+ .. math:: p(x) = \frac{am^a}{x^{a+1}}
- The function has its peak (the mode) at
- :math:`\lambda(\frac{a-1}{a})^{1/a}`.
+ where :math:`a` is the shape and :math:`m` the scale.
- When ``a = 1``, the Weibull distribution reduces to the exponential
- distribution.
+ The Pareto distribution, named after the Italian economist
+ Vilfredo Pareto, is a power law probability distribution
+ useful in many real world problems. Outside the field of
+ economics it is generally referred to as the Bradford
+ distribution. Pareto developed the distribution to describe
+ the distribution of wealth in an economy. It has also found
+ use in insurance, web page access statistics, oil field sizes,
+ and many other problems, including the download frequency for
+ projects in Sourceforge [1]_. It is one of the so-called
+ "fat-tailed" distributions.
References
----------
- .. [1] Waloddi Weibull, Royal Technical University, Stockholm,
- 1939 "A Statistical Theory Of The Strength Of Materials",
- Ingeniorsvetenskapsakademiens Handlingar Nr 151, 1939,
- Generalstabens Litografiska Anstalts Forlag, Stockholm.
- .. [2] Waloddi Weibull, "A Statistical Distribution Function of
- Wide Applicability", Journal Of Applied Mechanics ASME Paper
- 1951.
- .. [3] Wikipedia, "Weibull distribution",
- https://en.wikipedia.org/wiki/Weibull_distribution
+ .. [1] Francis Hunt and Paul Johnson, On the Pareto Distribution of
+ Sourceforge projects.
+ .. [2] Pareto, V. (1896). Course of Political Economy. Lausanne.
+ .. [3] Reiss, R.D., Thomas, M.(2001), Statistical Analysis of Extreme
+ Values, Birkhauser Verlag, Basel, pp 23-30.
+ .. [4] Wikipedia, "Pareto distribution",
+ https://en.wikipedia.org/wiki/Pareto_distribution
Examples
--------
Draw samples from the distribution:
- >>> a = 5. # shape
- >>> s = brainpy.math.random.weibull(a, 1000)
+ >>> a, m = 3., 2. # shape and mode
+ >>> s = (bm.random.pareto(a, 1000) + 1) * m
- Display the histogram of the samples, along with
- the probability density function:
+ Display the histogram of the samples, along with the probability
+ density function:
>>> import matplotlib.pyplot as plt
- >>> x = np.arange(1,100.)/50.
- >>> def weib(x,n,a):
- ... return (a / n) * (x / n)**(a - 1) * np.exp(-(x / n)**a)
+ >>> count, bins, _ = plt.hist(s, 100, density=True)
+ >>> fit = a*m**a / bins**(a+1)
+ >>> plt.plot(bins, max(count)*fit/max(fit), linewidth=2, color='r')
+ >>> plt.show()
+ """
+ return DEFAULT.pareto(a, size, key=key)
+
+
+def poisson(lam=1.0, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a Poisson distribution.
+
+ The Poisson distribution is the limit of the binomial distribution
+ for large N.
+
+ Parameters
+ ----------
+ lam : float or array_like of floats
+ Expected number of events occurring in a fixed-time interval,
+ must be >= 0. A sequence must be broadcastable over the requested
+ size.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``lam`` is a scalar. Otherwise,
+ ``np.array(lam).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized Poisson distribution.
+
+ Notes
+ -----
+ The Poisson distribution
+
+ .. math:: f(k; \lambda)=\frac{\lambda^k e^{-\lambda}}{k!}
+
+ For events with an expected separation :math:`\lambda` the Poisson
+ distribution :math:`f(k; \lambda)` describes the probability of
+ :math:`k` events occurring within the observed
+ interval :math:`\lambda`.
+
+ Because the output is limited to the range of the C int64 type, a
+ ValueError is raised when `lam` is within 10 sigma of the maximum
+ representable value.
+
+ References
+ ----------
+ .. [1] Weisstein, Eric W. "Poisson Distribution."
+ From MathWorld--A Wolfram Web Resource.
+ http://mathworld.wolfram.com/PoissonDistribution.html
+ .. [2] Wikipedia, "Poisson distribution",
+ https://en.wikipedia.org/wiki/Poisson_distribution
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> import numpy as np
+ >>> s = bm.random.poisson(5, 10000)
+
+ Display histogram of the sample:
+
+ >>> import matplotlib.pyplot as plt
+ >>> count, bins, ignored = plt.hist(s, 14, density=True)
+ >>> plt.show()
+
+ Draw each 100 values for lambda 100 and 500:
+
+ >>> s = bm.random.poisson(lam=(100., 500.), size=(100, 2))
+ """
+ return DEFAULT.poisson(lam, size, key=key)
+
+
+def standard_cauchy(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a standard Cauchy distribution with mode = 0.
+
+ Also known as the Lorentz distribution.
+
+ Parameters
+ ----------
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. Default is None, in which case a
+ single value is returned.
+
+ Returns
+ -------
+ samples : ndarray or scalar
+ The drawn samples.
+
+ Notes
+ -----
+ The probability density function for the full Cauchy distribution is
+
+ .. math:: P(x; x_0, \gamma) = \frac{1}{\pi \gamma \bigl[ 1+
+ (\frac{x-x_0}{\gamma})^2 \bigr] }
+
+ and the Standard Cauchy distribution just sets :math:`x_0=0` and
+ :math:`\gamma=1`
+
+ The Cauchy distribution arises in the solution to the driven harmonic
+ oscillator problem, and also describes spectral line broadening. It
+ also describes the distribution of values at which a line tilted at
+ a random angle will cut the x axis.
+
+ When studying hypothesis tests that assume normality, seeing how the
+ tests perform on data from a Cauchy distribution is a good indicator of
+ their sensitivity to a heavy-tailed distribution, since the Cauchy looks
+ very much like a Gaussian distribution, but with heavier tails.
+
+ References
+ ----------
+ .. [1] NIST/SEMATECH e-Handbook of Statistical Methods, "Cauchy
+ Distribution",
+ https://www.itl.nist.gov/div898/handbook/eda/section3/eda3663.htm
+ .. [2] Weisstein, Eric W. "Cauchy Distribution." From MathWorld--A
+ Wolfram Web Resource.
+ http://mathworld.wolfram.com/CauchyDistribution.html
+ .. [3] Wikipedia, "Cauchy distribution"
+ https://en.wikipedia.org/wiki/Cauchy_distribution
+
+ Examples
+ --------
+ Draw samples and plot the distribution:
+
+ >>> import matplotlib.pyplot as plt
+ >>> s = bm.random.standard_cauchy(1000000)
+ >>> s = s[(s>-25) & (s<25)] # truncate distribution so it plots well
+ >>> plt.hist(s, bins=100)
+ >>> plt.show()
+ """
+ return DEFAULT.standard_cauchy(size, key=key)
+
+
+def standard_exponential(size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from the standard exponential distribution.
+
+ `standard_exponential` is identical to the exponential distribution
+ with a scale parameter of 1.
+
+ Parameters
+ ----------
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. Default is None, in which case a
+ single value is returned.
+
+ Returns
+ -------
+ out : float or ndarray
+ Drawn samples.
+
+ Examples
+ --------
+ Output a 3x8000 array:
+
+ >>> n = bm.random.standard_exponential((3, 8000))
+ """
+ return DEFAULT.standard_exponential(size, key=key)
+
+
+def standard_gamma(shape, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a standard Gamma distribution.
+
+ Samples are drawn from a Gamma distribution with specified parameters,
+ shape (sometimes designated "k") and scale=1.
+
+ Parameters
+ ----------
+ shape : float or array_like of floats
+ Parameter, must be non-negative.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``shape`` is a scalar. Otherwise,
+ ``np.array(shape).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized standard gamma distribution.
+
+ See Also
+ --------
+ scipy.stats.gamma : probability density function, distribution or
+ cumulative density function, etc.
+
+ Notes
+ -----
+ The probability density for the Gamma distribution is
+
+ .. math:: p(x) = x^{k-1}\frac{e^{-x/\theta}}{\theta^k\Gamma(k)},
+
+ where :math:`k` is the shape and :math:`\theta` the scale,
+ and :math:`\Gamma` is the Gamma function.
+
+ The Gamma distribution is often used to model the times to failure of
+ electronic components, and arises naturally in processes for which the
+ waiting times between Poisson distributed events are relevant.
+
+ References
+ ----------
+ .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A
+ Wolfram Web Resource.
+ http://mathworld.wolfram.com/GammaDistribution.html
+ .. [2] Wikipedia, "Gamma distribution",
+ https://en.wikipedia.org/wiki/Gamma_distribution
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> shape, scale = 2., 1. # mean and width
+ >>> s = bm.random.standard_gamma(shape, 1000000)
+
+ Display the histogram of the samples, along with
+ the probability density function:
+
+ >>> import matplotlib.pyplot as plt
+ >>> import scipy.special as sps # doctest: +SKIP
+ >>> count, bins, ignored = plt.hist(s, 50, density=True)
+ >>> y = bins**(shape-1) * ((np.exp(-bins/scale))/ # doctest: +SKIP
+ ... (sps.gamma(shape) * scale**shape))
+ >>> plt.plot(bins, y, linewidth=2, color='r') # doctest: +SKIP
+ >>> plt.show()
+ """
+ return DEFAULT.standard_gamma(shape, size, key=key)
+
+
+def standard_normal(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a standard Normal distribution (mean=0, stdev=1).
+
+ Parameters
+ ----------
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. Default is None, in which case a
+ single value is returned.
+
+ Returns
+ -------
+ out : float or ndarray
+ A floating-point array of shape ``size`` of drawn samples, or a
+ single sample if ``size`` was not specified.
+
+ See Also
+ --------
+ normal :
+ Equivalent function with additional ``loc`` and ``scale`` arguments
+ for setting the mean and standard deviation.
+
+ Notes
+ -----
+ For random samples from the normal distribution with mean ``mu`` and
+ standard deviation ``sigma``, use one of::
+
+ mu + sigma * bm.random.standard_normal(size=...)
+ bm.random.normal(mu, sigma, size=...)
+
+ Examples
+ --------
+ >>> bm.random.standard_normal()
+ 2.1923875335537315 #random
+
+ >>> s = bm.random.standard_normal(8000)
+ >>> s
+ array([ 0.6888893 , 0.78096262, -0.89086505, ..., 0.49876311, # random
+ -0.38672696, -0.4685006 ]) # random
+ >>> s.shape
+ (8000,)
+ >>> s = bm.random.standard_normal(size=(3, 4, 2))
+ >>> s.shape
+ (3, 4, 2)
+
+ Two-by-four array of samples from the normal distribution with
+ mean 3 and standard deviation 2.5:
+
+ >>> 3 + 2.5 * bm.random.standard_normal(size=(2, 4))
+ array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
+ [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
+ """
+ return DEFAULT.standard_normal(size, key=key)
+
+
+def standard_t(df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a standard Student's t distribution with `df` degrees
+ of freedom.
+
+ A special case of the hyperbolic distribution. As `df` gets
+ large, the result resembles that of the standard normal
+ distribution (`standard_normal`).
+
+ Parameters
+ ----------
+ df : float or array_like of floats
+ Degrees of freedom, must be > 0.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``df`` is a scalar. Otherwise,
+ ``np.array(df).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized standard Student's t distribution.
+
+ Notes
+ -----
+ The probability density function for the t distribution is
+
+ .. math:: P(x, df) = \frac{\Gamma(\frac{df+1}{2})}{\sqrt{\pi df}
+ \Gamma(\frac{df}{2})}\Bigl( 1+\frac{x^2}{df} \Bigr)^{-(df+1)/2}
+
+ The t test is based on an assumption that the data come from a
+ Normal distribution. The t test provides a way to test whether
+ the sample mean (that is the mean calculated from the data) is
+ a good estimate of the true mean.
+
+ The derivation of the t-distribution was first published in
+ 1908 by William Gosset while working for the Guinness Brewery
+ in Dublin. Due to proprietary issues, he had to publish under
+ a pseudonym, and so he used the name Student.
+
+ References
+ ----------
+ .. [1] Dalgaard, Peter, "Introductory Statistics With R",
+ Springer, 2002.
+ .. [2] Wikipedia, "Student's t-distribution"
+ https://en.wikipedia.org/wiki/Student's_t-distribution
+
+ Examples
+ --------
+ From Dalgaard page 83 [1]_, suppose the daily energy intake for 11
+ women in kilojoules (kJ) is:
+
+ >>> intake = np.array([5260., 5470, 5640, 6180, 6390, 6515, 6805, 7515, \
+ ... 7515, 8230, 8770])
+
+ Does their energy intake deviate systematically from the recommended
+ value of 7725 kJ? Our null hypothesis will be the absence of deviation,
+ and the alternate hypothesis will be the presence of an effect that could be
+ either positive or negative, hence making our test 2-tailed.
+
+ Because we are estimating the mean and we have N=11 values in our sample,
+ we have N-1=10 degrees of freedom. We set our significance level to 95% and
+ compute the t statistic using the empirical mean and empirical standard
+ deviation of our intake. We use a ddof of 1 to base the computation of our
+ empirical standard deviation on an unbiased estimate of the variance (note:
+ the final estimate is not unbiased due to the concave nature of the square
+ root).
+
+ >>> np.mean(intake)
+ 6753.636363636364
+ >>> intake.std(ddof=1)
+ 1142.1232221373727
+ >>> t = (np.mean(intake)-7725)/(intake.std(ddof=1)/np.sqrt(len(intake)))
+ >>> t
+ -2.8207540608310198
+
+ We draw 1000000 samples from Student's t distribution with the adequate
+ degrees of freedom.
+
+ >>> import matplotlib.pyplot as plt
+ >>> s = bm.random.standard_t(10, size=1000000)
+ >>> h = plt.hist(s, bins=100, density=True)
+
+ Does our t statistic land in one of the two critical regions found at
+ both tails of the distribution?
+
+ >>> np.sum(np.abs(t) < np.abs(s)) / float(len(s))
+ 0.018318 #random < 0.05, statistic is in critical region
+
+ The probability value for this 2-tailed test is about 1.83%, which is
+ lower than the 5% pre-determined significance threshold.
+
+ Therefore, the probability of observing values as extreme as our intake
+ conditionally on the null hypothesis being true is too low, and we reject
+ the null hypothesis of no deviation.
+ """
+ return DEFAULT.standard_t(df, size, key=key)
+
+
+def uniform(low=0.0, high=1.0, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a uniform distribution.
+
+ Samples are uniformly distributed over the half-open interval
+ ``[low, high)`` (includes low, but excludes high). In other words,
+ any value within the given interval is equally likely to be drawn
+ by `uniform`.
+
+ Parameters
+ ----------
+ low : float or array_like of floats, optional
+ Lower boundary of the output interval. All values generated will be
+ greater than or equal to low. The default value is 0.
+ high : float or array_like of floats
+ Upper boundary of the output interval. All values generated will be
+ less than or equal to high. The high limit may be included in the
+ returned array of floats due to floating-point rounding in the
+ equation ``low + (high-low) * random_sample()``. The default value
+ is 1.0.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``low`` and ``high`` are both scalars.
+ Otherwise, ``np.broadcast(low, high).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized uniform distribution.
+
+ See Also
+ --------
+ randint : Discrete uniform distribution, yielding integers.
+ random_integers : Discrete uniform distribution over the closed
+ interval ``[low, high]``.
+ random_sample : Floats uniformly distributed over ``[0, 1)``.
+ random : Alias for `random_sample`.
+ rand : Convenience function that accepts dimensions as input, e.g.,
+ ``rand(2,2)`` would generate a 2-by-2 array of floats,
+ uniformly distributed over ``[0, 1)``.
+
+ Notes
+ -----
+ The probability density function of the uniform distribution is
+
+ .. math:: p(x) = \frac{1}{b - a}
+
+ anywhere within the interval ``[a, b)``, and zero elsewhere.
+
+ When ``high`` == ``low``, values of ``low`` will be returned.
+ If ``high`` < ``low``, the results are officially undefined
+ and may eventually raise an error, i.e. do not rely on this
+ function to behave when passed arguments satisfying that
+ inequality condition. The ``high`` limit may be included in the
+ returned array of floats due to floating-point rounding in the
+ equation ``low + (high-low) * random_sample()``. For example:
+
+ >>> x = np.float32(5*0.99999999)
+ >>> x
+ 5.0
+
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> s = bm.random.uniform(-1,0,1000)
+
+ All values are within the given interval:
+
+ >>> np.all(s >= -1)
+ True
+ >>> np.all(s < 0)
+ True
+
+ Display the histogram of the samples, along with the
+ probability density function:
+
+ >>> import matplotlib.pyplot as plt
+ >>> count, bins, ignored = plt.hist(s, 15, density=True)
+ >>> plt.plot(bins, np.ones_like(bins), linewidth=2, color='r')
+ >>> plt.show()
+ """
+ return DEFAULT.uniform(low, high, size, key=key)
+
+
+def truncated_normal(lower, upper, size: Optional[Union[int, Sequence[int]]] = None, loc=0., scale=1., dtype=float,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""Sample truncated standard normal random values with given shape and dtype.
+
+ Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+
+
+ Notes
+ -----
+ This distribution is the normal distribution centered on ``loc`` (default
+ 0), with standard deviation ``scale`` (default 1), and clipped at ``a``,
+ ``b`` standard deviations to the left, right (respectively) from ``loc``.
+ If ``myclip_a`` and ``myclip_b`` are clip values in the sample space (as
+ opposed to the number of standard deviations) then they can be converted
+ to the required form according to::
+
+ a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale
+
+
+ Parameters
+ ----------
+ lower : float, ndarray
+ A float or array of floats representing the lower bound for
+ truncation. Must be broadcast-compatible with ``upper``.
+ upper : float, ndarray
+ A float or array of floats representing the upper bound for
+ truncation. Must be broadcast-compatible with ``lower``.
+ size : optional, list of int, tuple of int
+ A tuple of nonnegative integers specifying the result
+ shape. Must be broadcast-compatible with ``lower`` and ``upper``. The
+ default (None) produces a result shape by broadcasting ``lower`` and
+ ``upper``.
+ loc : optional, float, ndarray
+ A float or array of floats representing the mean ("centre") of the
+ distribution before truncating. Default is 0. Note that the mean of
+ the truncated distribution will not be exactly equal to ``loc``.
+ scale : float, ndarray
+ Standard deviation (spread or "width") of the distribution. Must be
+ non-negative. Default is 1.
+ dtype: optional
+ The float dtype for the returned values (default float64 if
+ jax_enable_x64 is true, otherwise float32).
+ key: jax.Array
+ The key for random generator. Consistent with the jax's random
+ paradigm.
+
+ Returns
+ -------
+ out : Array
+ A random array with the specified dtype and shape given by ``size`` if
+ ``size`` is not None, or else by broadcasting ``lower`` and ``upper``.
+ Returns values in the open interval ``(lower, upper)``.
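+
+ Examples
+ --------
+ A minimal usage sketch; the clip values, ``loc`` and ``scale`` below are
+ illustrative only. Following the conversion described in the Notes,
+ sample-space clip values are first rescaled to standard deviations from
+ ``loc``:
+
+ >>> myclip_a, myclip_b, loc, scale = -1., 3., 1., 2.
+ >>> a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale
+ >>> s = bm.random.truncated_normal(a, b, size=(1000,), loc=loc, scale=scale)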
+ """
+ return DEFAULT.truncated_normal(lower, upper, size, loc, scale, dtype=dtype, key=key)
+
+
+RandomState.truncated_normal.__doc__ = truncated_normal.__doc__
+
+
+def bernoulli(p=0.5, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""Sample Bernoulli random values with given shape and mean.
+
+ Parameters
+ ----------
+ p: float, array_like, optional
+ A float or array of floats for the mean of the random
+ variables. Must be broadcast-compatible with ``size`` and the values
+ should be within [0, 1]. Default 0.5.
+ size: optional, tuple of int, int
+ A tuple of nonnegative integers representing the result
+ shape. Must be broadcast-compatible with ``p.shape``. The default (None)
+ produces a result shape equal to ``p.shape``.
+
+ Returns
+ -------
+ out: array_like
+ A random array with boolean dtype and shape given by ``size`` if ``size``
+ is not None, or else ``p.shape``.
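+
+ Examples
+ --------
+ A minimal usage sketch; the mean and sample size below are
+ illustrative only:
+
+ >>> flips = bm.random.bernoulli(0.3, size=(1000,))
+ >>> float(flips.mean()) # fraction of successes, roughly 0.3
+ 0.296 # may vary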
+ """
+ return DEFAULT.bernoulli(p, size, key=key)
+
+
+def lognormal(mean=None, sigma=None, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a log-normal distribution.
+
+ Draw samples from a log-normal distribution with specified mean,
+ standard deviation, and array shape. Note that the mean and standard
+ deviation are not the values for the distribution itself, but of the
+ underlying normal distribution it is derived from.
+
+ Parameters
+ ----------
+ mean : float or array_like of floats, optional
+ Mean value of the underlying normal distribution. Default is 0.
+ sigma : float or array_like of floats, optional
+ Standard deviation of the underlying normal distribution. Must be
+ non-negative. Default is 1.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``mean`` and ``sigma`` are both scalars.
+ Otherwise, ``np.broadcast(mean, sigma).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized log-normal distribution.
+
+ See Also
+ --------
+ scipy.stats.lognorm : probability density function, distribution,
+ cumulative density function, etc.
+
+ Notes
+ -----
+ A variable `x` has a log-normal distribution if `log(x)` is normally
+ distributed. The probability density function for the log-normal
+ distribution is:
+
+ .. math:: p(x) = \frac{1}{\sigma x \sqrt{2\pi}}
+ e^{(-\frac{(ln(x)-\mu)^2}{2\sigma^2})}
+
+ where :math:`\mu` is the mean and :math:`\sigma` is the standard
+ deviation of the normally distributed logarithm of the variable.
+ A log-normal distribution results if a random variable is the *product*
+ of a large number of independent, identically-distributed variables in
+ the same way that a normal distribution results if the variable is the
+ *sum* of a large number of independent, identically-distributed
+ variables.
+
+ References
+ ----------
+ .. [1] Limpert, E., Stahel, W. A., and Abbt, M., "Log-normal
+ Distributions across the Sciences: Keys and Clues,"
+ BioScience, Vol. 51, No. 5, May, 2001.
+ https://stat.ethz.ch/~stahel/lognormal/bioscience.pdf
+ .. [2] Reiss, R.D. and Thomas, M., "Statistical Analysis of Extreme
+ Values," Basel: Birkhauser Verlag, 2001, pp. 31-32.
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> mu, sigma = 3., 1. # mean and standard deviation
+ >>> s = bm.random.lognormal(mu, sigma, 1000)
+
+ Display the histogram of the samples, along with
+ the probability density function:
+
+ >>> import matplotlib.pyplot as plt
+ >>> count, bins, ignored = plt.hist(s, 100, density=True, align='mid')
+
+ >>> x = np.linspace(min(bins), max(bins), 10000)
+ >>> pdf = (np.exp(-(np.log(x) - mu)**2 / (2 * sigma**2))
+ ... / (x * sigma * np.sqrt(2 * np.pi)))
+
+ >>> plt.plot(x, pdf, linewidth=2, color='r')
+ >>> plt.axis('tight')
+ >>> plt.show()
+
+ Demonstrate that taking the products of random samples from a normal
+ distribution can be fit well by a log-normal probability density
+ function.
+
+ >>> # Generate a thousand samples: each is the product of 100 random
+ >>> # values, drawn from a normal distribution.
+ >>> b = []
+ >>> for i in range(1000):
+ ... a = 10. + bm.random.standard_normal(100)
+ ... b.append(np.prod(a))
+
+ >>> b = np.array(b) / np.min(b) # scale values to be positive
+ >>> count, bins, ignored = plt.hist(b, 100, density=True, align='mid')
+ >>> sigma = np.std(np.log(b))
+ >>> mu = np.mean(np.log(b))
+
+ >>> x = np.linspace(min(bins), max(bins), 10000)
+ >>> pdf = (np.exp(-(np.log(x) - mu)**2 / (2 * sigma**2))
+ ... / (x * sigma * np.sqrt(2 * np.pi)))
+
+ >>> plt.plot(x, pdf, color='r', linewidth=2)
+ >>> plt.show()
+ """
+ return DEFAULT.lognormal(mean, sigma, size, key=key)
+
+
+def binomial(n, p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a binomial distribution.
+
+ Samples are drawn from a binomial distribution with specified
+ parameters, n trials and p probability of success where
+ n an integer >= 0 and p is in the interval [0,1]. (n may be
+ input as a float, but it is truncated to an integer in use)
+
+ Parameters
+ ----------
+ n : int or array_like of ints
+ Parameter of the distribution, >= 0. Floats are also accepted,
+ but they will be truncated to integers.
+ p : float or array_like of floats
+ Parameter of the distribution, >= 0 and <=1.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``n`` and ``p`` are both scalars.
+ Otherwise, ``np.broadcast(n, p).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized binomial distribution, where
+ each sample is equal to the number of successes over the n trials.
+
+ See Also
+ --------
+ scipy.stats.binom : probability density function, distribution or
+ cumulative density function, etc.
+
+ Notes
+ -----
+ The probability density for the binomial distribution is
+
+ .. math:: P(N) = \binom{n}{N}p^N(1-p)^{n-N},
+
+ where :math:`n` is the number of trials, :math:`p` is the probability
+ of success, and :math:`N` is the number of successes.
+
+ When estimating the standard error of a proportion in a population by
+ using a random sample, the normal distribution works well unless the
+ product p*n <=5, where p = population proportion estimate, and n =
+ number of samples, in which case the binomial distribution is used
+ instead. For example, a sample of 15 people shows 4 who are left
+ handed, and 11 who are right handed. Then p = 4/15 = 27%. 0.27*15 = 4,
+ so the binomial distribution should be used in this case.
+
+ References
+ ----------
+ .. [1] Dalgaard, Peter, "Introductory Statistics with R",
+ Springer-Verlag, 2002.
+ .. [2] Glantz, Stanton A. "Primer of Biostatistics.", McGraw-Hill,
+ Fifth Edition, 2002.
+ .. [3] Lentner, Marvin, "Elementary Applied Statistics", Bogden
+ and Quigley, 1972.
+ .. [4] Weisstein, Eric W. "Binomial Distribution." From MathWorld--A
+ Wolfram Web Resource.
+ http://mathworld.wolfram.com/BinomialDistribution.html
+ .. [5] Wikipedia, "Binomial distribution",
+ https://en.wikipedia.org/wiki/Binomial_distribution
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> n, p = 10, .5 # number of trials, probability of each trial
+ >>> s = bm.random.binomial(n, p, 1000)
+ # result of flipping a coin 10 times, tested 1000 times.
+
+ A real world example. A company drills 9 wild-cat oil exploration
+ wells, each with an estimated probability of success of 0.1. All nine
+ wells fail. What is the probability of that happening?
+
+ Let's do 20,000 trials of the model, and count the number that
+ generate zero positive results.
+
+ >>> sum(bm.random.binomial(9, 0.1, 20000) == 0)/20000.
+ # answer = 0.38885, or 38%.
+ """
+ return DEFAULT.binomial(n, p, size, key=key)
+
+
+def chisquare(df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a chi-square distribution.
+
+ When `df` independent random variables, each with standard normal
+ distributions (mean 0, variance 1), are squared and summed, the
+ resulting distribution is chi-square (see Notes). This distribution
+ is often used in hypothesis testing.
+
+ Parameters
+ ----------
+ df : float or array_like of floats
+ Number of degrees of freedom, must be > 0.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``df`` is a scalar. Otherwise,
+ ``np.array(df).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized chi-square distribution.
+
+ Raises
+ ------
+ ValueError
+ When `df` <= 0 or when an inappropriate `size` (e.g. ``size=-1``)
+ is given.
+
+ Notes
+ -----
+ The variable obtained by summing the squares of `df` independent,
+ standard normally distributed random variables:
+
+ .. math:: Q = \sum_{i=1}^{\mathtt{df}} X^2_i
+
+ is chi-square distributed, denoted
+
+ .. math:: Q \sim \chi^2_k.
+
+ The probability density function of the chi-squared distribution is
+
+ .. math:: p(x) = \frac{(1/2)^{k/2}}{\Gamma(k/2)}
+ x^{k/2 - 1} e^{-x/2},
+
+ where :math:`\Gamma` is the gamma function,
+
+ .. math:: \Gamma(x) = \int_0^{\infty} t^{x - 1} e^{-t} dt.
+
+ References
+ ----------
+ .. [1] NIST "Engineering Statistics Handbook"
+ https://www.itl.nist.gov/div898/handbook/eda/section3/eda3666.htm
+
+ Examples
+ --------
+ >>> bm.random.chisquare(2,4)
+ array([ 1.89920014, 9.00867716, 3.13710533, 5.62318272]) # random
+ """
+ return DEFAULT.chisquare(df, size, key=key)
+
+
+def dirichlet(alpha, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from the Dirichlet distribution.
+
+ Draw `size` samples of dimension k from a Dirichlet distribution. A
+ Dirichlet-distributed random variable can be seen as a multivariate
+ generalization of a Beta distribution. The Dirichlet distribution
+ is a conjugate prior of a multinomial distribution in Bayesian
+ inference.
+
+ Parameters
+ ----------
+ alpha : sequence of floats, length k
+ Parameter of the distribution (length ``k`` for sample of
+ length ``k``).
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n)``, then
+ ``m * n * k`` samples are drawn. Default is None, in which case a
+ vector of length ``k`` is returned.
+
+ Returns
+ -------
+ samples : ndarray,
+ The drawn samples, of shape ``(size, k)``.
+
+ Raises
+ ------
+ ValueError
+ If any value in ``alpha`` is less than or equal to zero
+
+ Notes
+ -----
+ The Dirichlet distribution is a distribution over vectors
+ :math:`x` that fulfil the conditions :math:`x_i>0` and
+ :math:`\sum_{i=1}^k x_i = 1`.
+
+ The probability density function :math:`p` of a
+ Dirichlet-distributed random vector :math:`X` is
+ proportional to
+
+ .. math:: p(x) \propto \prod_{i=1}^{k}{x^{\alpha_i-1}_i},
+
+ where :math:`\alpha` is a vector containing the positive
+ concentration parameters.
+
+ The method uses the following property for computation: let :math:`Y`
+ be a random vector which has components that follow a standard gamma
+ distribution, then :math:`X = \frac{1}{\sum_{i=1}^k{Y_i}} Y`
+ is Dirichlet-distributed
+
+ References
+ ----------
+ .. [1] David McKay, "Information Theory, Inference and Learning
+ Algorithms," chapter 23,
+ http://www.inference.org.uk/mackay/itila/
+ .. [2] Wikipedia, "Dirichlet distribution",
+ https://en.wikipedia.org/wiki/Dirichlet_distribution
+
+ Examples
+ --------
+ Taking an example cited in Wikipedia, this distribution can be used if
+ one wanted to cut strings (each of initial length 1.0) into K pieces
+ with different lengths, where each piece had, on average, a designated
+ average length, but allowing some variation in the relative sizes of
+ the pieces.
+
+ >>> s = bm.random.dirichlet((10, 5, 3), 20).transpose()
+
+ >>> import matplotlib.pyplot as plt
+ >>> plt.barh(range(20), s[0])
+ >>> plt.barh(range(20), s[1], left=s[0], color='g')
+ >>> plt.barh(range(20), s[2], left=s[0]+s[1], color='r')
+ >>> plt.title("Lengths of Strings")
+ """
+ return DEFAULT.dirichlet(alpha, size, key=key)
+
+
+def geometric(p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from the geometric distribution.
+
+ Bernoulli trials are experiments with one of two outcomes:
+ success or failure (an example of such an experiment is flipping
+ a coin). The geometric distribution models the number of trials
+ that must be run in order to achieve success. It is therefore
+ supported on the positive integers, ``k = 1, 2, ...``.
+
+ The probability mass function of the geometric distribution is
+
+ .. math:: f(k) = (1 - p)^{k - 1} p
+
+ where `p` is the probability of success of an individual trial.
+
+ Parameters
+ ----------
+ p : float or array_like of floats
+ The probability of success of an individual trial.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``p`` is a scalar. Otherwise,
+ ``np.array(p).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized geometric distribution.
+
+ Examples
+ --------
+ Draw ten thousand values from the geometric distribution,
+ with the probability of an individual success equal to 0.35:
+
+ >>> z = bm.random.geometric(p=0.35, size=10000)
+
+ How many trials succeeded after a single run?
+
+ >>> (z == 1).sum() / 10000.
+ 0.34889999999999999 #random
+ """
+ return DEFAULT.geometric(p, size, key=key)
+
+
+def f(dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from an F distribution.
+
+ Samples are drawn from an F distribution with specified parameters,
+ `dfnum` (degrees of freedom in numerator) and `dfden` (degrees of
+ freedom in denominator), where both parameters must be greater than
+ zero.
+
+ The random variate of the F distribution (also known as the
+ Fisher distribution) is a continuous probability distribution
+ that arises in ANOVA tests, and is the ratio of two chi-square
+ variates.
+
+ Parameters
+ ----------
+ dfnum : float or array_like of floats
+ Degrees of freedom in numerator, must be > 0.
+ dfden : float or array_like of float
+ Degrees of freedom in denominator, must be > 0.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``dfnum`` and ``dfden`` are both scalars.
+ Otherwise, ``np.broadcast(dfnum, dfden).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized Fisher distribution.
+
+ See Also
+ --------
+ scipy.stats.f : probability density function, distribution or
+ cumulative density function, etc.
+
+ Notes
+ -----
+ The F statistic is used to compare in-group variances to between-group
+ variances. Calculating the distribution depends on the sampling, and
+ so it is a function of the respective degrees of freedom in the
+ problem. The variable `dfnum` is the number of samples minus one, the
+ between-groups degrees of freedom, while `dfden` is the within-groups
+ degrees of freedom, the sum of the number of samples in each group
+ minus the number of groups.
+
+ References
+ ----------
+ .. [1] Glantz, Stanton A. "Primer of Biostatistics.", McGraw-Hill,
+ Fifth Edition, 2002.
+ .. [2] Wikipedia, "F-distribution",
+ https://en.wikipedia.org/wiki/F-distribution
+
+ Examples
+ --------
+ An example from Glantz[1], pp 47-40:
+
+ Two groups, children of diabetics (25 people) and children from people
+ without diabetes (25 controls). Fasting blood glucose was measured,
+ case group had a mean value of 86.1, controls had a mean value of
+ 82.2. Standard deviations were 2.09 and 2.49 respectively. Are these
+ data consistent with the null hypothesis that the parents diabetic
+ status does not affect their children's blood glucose levels?
+ Calculating the F statistic from the data gives a value of 36.01.
+
+ Draw samples from the distribution:
+
+ >>> dfnum = 1. # between group degrees of freedom
+ >>> dfden = 48. # within groups degrees of freedom
+ >>> s = bm.random.f(dfnum, dfden, 1000)
+
+ The lower bound for the top 1% of the samples is :
+
+ >>> np.sort(s)[-10]
+ 7.61988120985 # random
+
+ So there is about a 1% chance that the F statistic will exceed 7.62,
+ the measured value is 36, so the null hypothesis is rejected at the 1%
+ level.
+ """
+ return DEFAULT.f(dfnum, dfden, size, key=key)
+
+
+def hypergeometric(ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a Hypergeometric distribution.
+
+ Samples are drawn from a hypergeometric distribution with specified
+ parameters, `ngood` (ways to make a good selection), `nbad` (ways to make
+ a bad selection), and `nsample` (number of items sampled, which is less
+ than or equal to the sum ``ngood + nbad``).
+
+ Parameters
+ ----------
+ ngood : int or array_like of ints
+ Number of ways to make a good selection. Must be nonnegative.
+ nbad : int or array_like of ints
+ Number of ways to make a bad selection. Must be nonnegative.
+ nsample : int or array_like of ints
+ Number of items sampled. Must be at least 1 and at most
+ ``ngood + nbad``.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if `ngood`, `nbad`, and `nsample`
+ are all scalars. Otherwise, ``np.broadcast(ngood, nbad, nsample).size``
+ samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized hypergeometric distribution. Each
+ sample is the number of good items within a randomly selected subset of
+ size `nsample` taken from a set of `ngood` good items and `nbad` bad items.
+
+ See Also
+ --------
+ scipy.stats.hypergeom : probability density function, distribution or
+ cumulative density function, etc.
+
+ Notes
+ -----
+ The probability density for the Hypergeometric distribution is
+
+ .. math:: P(x) = \frac{\binom{g}{x}\binom{b}{n-x}}{\binom{g+b}{n}},
+
+ where :math:`0 \le x \le n` and :math:`n-b \le x \le g`
+
+ for P(x) the probability of ``x`` good results in the drawn sample,
+ g = `ngood`, b = `nbad`, and n = `nsample`.
+
+ Consider an urn with black and white marbles in it, `ngood` of them
+ are black and `nbad` are white. If you draw `nsample` balls without
+ replacement, then the hypergeometric distribution describes the
+ distribution of black balls in the drawn sample.
+
+ Note that this distribution is very similar to the binomial
+ distribution, except that in this case, samples are drawn without
+ replacement, whereas in the Binomial case samples are drawn with
+ replacement (or the sample space is infinite). As the sample space
+ becomes large, this distribution approaches the binomial.
+
+ References
+ ----------
+ .. [1] Lentner, Marvin, "Elementary Applied Statistics", Bogden
+ and Quigley, 1972.
+ .. [2] Weisstein, Eric W. "Hypergeometric Distribution." From
+ MathWorld--A Wolfram Web Resource.
+ http://mathworld.wolfram.com/HypergeometricDistribution.html
+ .. [3] Wikipedia, "Hypergeometric distribution",
+ https://en.wikipedia.org/wiki/Hypergeometric_distribution
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> ngood, nbad, nsamp = 100, 2, 10
+ # number of good, number of bad, and number of samples
+ >>> s = bm.random.hypergeometric(ngood, nbad, nsamp, 1000)
+ >>> from matplotlib.pyplot import hist
+ >>> hist(s)
+ # note that it is very unlikely to grab both bad items
+
+ Suppose you have an urn with 15 white and 15 black marbles.
+ If you pull 15 marbles at random, how likely is it that
+ 12 or more of them are one color?
+
+ >>> s = bm.random.hypergeometric(15, 15, 15, 100000)
+ >>> sum(s>=12)/100000. + sum(s<=3)/100000.
+ # answer = 0.003 ... pretty unlikely!
+ """
+ return DEFAULT.hypergeometric(ngood, nbad, nsample, size, key=key)
+
+
+def logseries(p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a logarithmic series distribution.
+
+ Samples are drawn from a log series distribution with specified
+ shape parameter, 0 <= ``p`` < 1.
+
+ Parameters
+ ----------
+ p : float or array_like of floats
+ Shape parameter for the distribution. Must be in the range [0, 1).
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``p`` is a scalar. Otherwise,
+ ``np.array(p).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized logarithmic series distribution.
+
+ See Also
+ --------
+ scipy.stats.logser : probability density function, distribution or
+ cumulative density function, etc.
+
+ Notes
+ -----
+ The probability density for the Log Series distribution is
+
+ .. math:: P(k) = \frac{-p^k}{k \ln(1-p)},
+
+ where p = probability.
+
+ The log series distribution is frequently used to represent species
+ richness and occurrence, first proposed by Fisher, Corbet, and
+ Williams in 1943 [2]. It may also be used to model the numbers of
+ occupants seen in cars [3].
+
+ References
+ ----------
+ .. [1] Buzas, Martin A.; Culver, Stephen J., Understanding regional
+ species diversity through the log series distribution of
+ occurrences: BIODIVERSITY RESEARCH Diversity & Distributions,
+ Volume 5, Number 5, September 1999 , pp. 187-195(9).
+ .. [2] Fisher, R.A,, A.S. Corbet, and C.B. Williams. 1943. The
+ relation between the number of species and the number of
+ individuals in a random sample of an animal population.
+ Journal of Animal Ecology, 12:42-58.
+ .. [3] D. J. Hand, F. Daly, D. Lunn, E. Ostrowski, A Handbook of Small
+ Data Sets, CRC Press, 1994.
+ .. [4] Wikipedia, "Logarithmic distribution",
+ https://en.wikipedia.org/wiki/Logarithmic_distribution
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> a = .6
+ >>> s = bm.random.logseries(a, 10000)
+ >>> import matplotlib.pyplot as plt
+ >>> count, bins, ignored = plt.hist(s)
+
+ # plot against distribution
+
+ >>> def logseries(k, p):
+ ... return -p**k/(k*np.log(1-p))
+ >>> plt.plot(bins, logseries(bins, a)*count.max()/
+ ... logseries(bins, a).max(), 'r')
+ >>> plt.show()
+ """
+ return DEFAULT.logseries(p, size, key=key)
+
+
+def multinomial(n, pvals, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a multinomial distribution.
+
+ The multinomial distribution is a multivariate generalization of the
+ binomial distribution. Take an experiment with one of ``p``
+ possible outcomes. An example of such an experiment is throwing a die,
+ where the outcome can be 1 through 6. Each sample drawn from the
+ distribution represents `n` such experiments. Its values,
+ ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the
+ outcome was ``i``.
+
+ Parameters
+ ----------
+ n : int
+ Number of experiments.
+ pvals : sequence of floats, length p
+ Probabilities of each of the ``p`` different outcomes. These
+ must sum to 1 (however, the last element is always assumed to
+ account for the remaining probability, as long as
+ ``sum(pvals[:-1]) <= 1)``.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. Default is None, in which case a
+ single value is returned.
+
+ Returns
+ -------
+ out : ndarray
+ The drawn samples, of shape *size*, if that was provided. If not,
+ the shape is ``(N,)``.
+
+ In other words, each entry ``out[i,j,...,:]`` is an N-dimensional
+ value drawn from the distribution.
+
+ Examples
+ --------
+ Throw a die 20 times:
+
+ >>> bm.random.multinomial(20, [1/6.]*6, size=1)
+ array([[4, 1, 7, 5, 2, 1]]) # random
+
+ It landed 4 times on 1, once on 2, etc.
+
+ Now, throw the dice 20 times, and 20 times again:
+
+ >>> bm.random.multinomial(20, [1/6.]*6, size=2)
+ array([[3, 4, 3, 3, 4, 3], # random
+ [2, 4, 3, 4, 0, 7]])
+
+ For the first run, we threw 3 times 1, 4 times 2, etc. For the second,
+ we threw 2 times 1, 4 times 2, etc.
+
+ A loaded die is more likely to land on number 6:
+
+ >>> bm.random.multinomial(100, [1/7.]*5 + [2/7.])
+ array([11, 16, 14, 17, 16, 26]) # random
+
+ The probability inputs should be normalized. As an implementation
+ detail, the value of the last entry is ignored and assumed to take
+ up any leftover probability mass, but this should not be relied on.
+ A biased coin which has twice as much weight on one side as on the
+ other should be sampled like so:
+
+ >>> bm.random.multinomial(100, [1.0 / 3, 2.0 / 3]) # RIGHT
+ array([38, 62]) # random
+
+ not like:
+
+ >>> bm.random.multinomial(100, [1.0, 2.0]) # WRONG
+ Traceback (most recent call last):
+ ValueError: pvals < 0, pvals > 1 or pvals contains NaNs
+ """
+ return DEFAULT.multinomial(n, pvals, size, key=key)
+
+
+def multivariate_normal(mean, cov, size: Optional[Union[int, Sequence[int]]] = None, method: str = 'cholesky',
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw random samples from a multivariate normal distribution.
+
+ The multivariate normal, multinormal or Gaussian distribution is a
+ generalization of the one-dimensional normal distribution to higher
+ dimensions. Such a distribution is specified by its mean and
+ covariance matrix. These parameters are analogous to the mean
+ (average or "center") and variance (standard deviation, or "width,"
+ squared) of the one-dimensional normal distribution.
+
+ Parameters
+ ----------
+ mean : 1-D array_like, of length N
+ Mean of the N-dimensional distribution.
+ cov : 2-D array_like, of shape (N, N)
+ Covariance matrix of the distribution. It must be symmetric and
+ positive-semidefinite for proper sampling.
+ size : int or tuple of ints, optional
+ Given a shape of, for example, ``(m,n,k)``, ``m*n*k`` samples are
+ generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because
+ each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``.
+ If no shape is specified, a single (`N`-D) sample is returned.
+ method : str, optional
+ The matrix decomposition used to transform standard normal draws into
+ samples from this distribution. Default is ``'cholesky'``.
+
+ Returns
+ -------
+ out : ndarray
+ The drawn samples, of shape *size*, if that was provided. If not,
+ the shape is ``(N,)``.
+
+ In other words, each entry ``out[i,j,...,:]`` is an N-dimensional
+ value drawn from the distribution.
+
+ Notes
+ -----
+ The mean is a coordinate in N-dimensional space, which represents the
+ location where samples are most likely to be generated. This is
+ analogous to the peak of the bell curve for the one-dimensional or
+ univariate normal distribution.
+
+ Covariance indicates the level to which two variables vary together.
+ From the multivariate normal distribution, we draw N-dimensional
+ samples, :math:`X = [x_1, x_2, ... x_N]`. The covariance matrix
+ element :math:`C_{ij}` is the covariance of :math:`x_i` and :math:`x_j`.
+ The element :math:`C_{ii}` is the variance of :math:`x_i` (i.e. its
+ "spread").
+
+ Instead of specifying the full covariance matrix, popular
+ approximations include:
+
+ - Spherical covariance (`cov` is a multiple of the identity matrix)
+ - Diagonal covariance (`cov` has non-negative elements, and only on
+ the diagonal)
+
+ This geometrical property can be seen in two dimensions by plotting
+ generated data-points:
+
+ >>> mean = [0, 0]
+ >>> cov = [[1, 0], [0, 100]] # diagonal covariance
+
+ Diagonal covariance means that points are oriented along x or y-axis:
+
+ >>> import matplotlib.pyplot as plt
+ >>> x, y = bm.random.multivariate_normal(mean, cov, 5000).T
+ >>> plt.plot(x, y, 'x')
+ >>> plt.axis('equal')
+ >>> plt.show()
+
+ Note that the covariance matrix must be positive semidefinite (a.k.a.
+ nonnegative-definite). Otherwise, the behavior of this method is
+ undefined and backwards compatibility is not guaranteed.
+
+ References
+ ----------
+ .. [1] Papoulis, A., "Probability, Random Variables, and Stochastic
+ Processes," 3rd ed., New York: McGraw-Hill, 1991.
+ .. [2] Duda, R. O., Hart, P. E., and Stork, D. G., "Pattern
+ Classification," 2nd ed., New York: Wiley, 2001.
+
+ Examples
+ --------
+ >>> mean = (1, 2)
+ >>> cov = [[1, 0], [0, 1]]
+ >>> x = bm.random.multivariate_normal(mean, cov, (3, 3))
+ >>> x.shape
+ (3, 3, 2)
+
+ Here we generate 800 samples from the bivariate normal distribution
+ with mean [0, 0] and covariance matrix [[6, -3], [-3, 3.5]]. The
+ expected variances of the first and second components of the sample
+ are 6 and 3.5, respectively, and the expected correlation
+ coefficient is -3/sqrt(6*3.5) ≈ -0.65465.
+
+ >>> cov = np.array([[6, -3], [-3, 3.5]])
+ >>> pts = bm.random.multivariate_normal([0, 0], cov, size=800)
+
+ Check that the mean, covariance, and correlation coefficient of the
+ sample are close to the expected values:
+
+ >>> pts.mean(axis=0)
+ array([ 0.0326911 , -0.01280782]) # may vary
+ >>> np.cov(pts.T)
+ array([[ 5.96202397, -2.85602287],
+ [-2.85602287, 3.47613949]]) # may vary
+ >>> np.corrcoef(pts.T)[0, 1]
+ -0.6273591314603949 # may vary
+
+ We can visualize this data with a scatter plot. The orientation
+ of the point cloud illustrates the negative correlation of the
+ components of this sample.
+
+ >>> import matplotlib.pyplot as plt
+ >>> plt.plot(pts[:, 0], pts[:, 1], '.', alpha=0.5)
+ >>> plt.axis('equal')
+ >>> plt.grid()
+ >>> plt.show()
+ """
+ return DEFAULT.multivariate_normal(mean, cov, size, method, key=key)
+
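A short sketch of the ``size``, ``method`` and ``key`` arguments (assuming the usual ``bm`` alias); for a large sample the empirical covariance should approach ``cov``.

import numpy as np
import brainpy.math as bm

mean = np.zeros(2)
cov = np.array([[6.0, -3.0], [-3.0, 3.5]])
pts = bm.random.multivariate_normal(mean, cov, size=2000, method='cholesky', key=0)
print(pts.shape)        # (2000, 2)
print(np.cov(pts.T))    # roughly [[6, -3], [-3, 3.5]]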
+
+def negative_binomial(n, p, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a negative binomial distribution.
+
+ Samples are drawn from a negative binomial distribution with specified
+ parameters, `n` successes and `p` probability of success where `n`
+ is > 0 and `p` is in the interval [0, 1].
+
+ Parameters
+ ----------
+ n : float or array_like of floats
+ Parameter of the distribution, > 0.
+ p : float or array_like of floats
+ Parameter of the distribution, >= 0 and <=1.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``n`` and ``p`` are both scalars.
+ Otherwise, ``np.broadcast(n, p).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized negative binomial distribution,
+ where each sample is equal to N, the number of failures that
+ occurred before a total of n successes was reached.
+
+ Notes
+ -----
+ The probability mass function of the negative binomial distribution is
+
+ .. math:: P(N;n,p) = \frac{\Gamma(N+n)}{N!\Gamma(n)}p^{n}(1-p)^{N},
+
+ where :math:`n` is the number of successes, :math:`p` is the
+ probability of success, :math:`N+n` is the number of trials, and
+ :math:`\Gamma` is the gamma function. When :math:`n` is an integer,
+ :math:`\frac{\Gamma(N+n)}{N!\Gamma(n)} = \binom{N+n-1}{N}`, which is
+ the more common form of this term in the pmf. The negative
+ binomial distribution gives the probability of N failures given n
+ successes, with a success on the last trial.
+
+ If one throws a die repeatedly until the third time a "1" appears,
+ then the probability distribution of the number of non-"1"s that
+ appear before the third "1" is a negative binomial distribution.
+
+ References
+ ----------
+ .. [1] Weisstein, Eric W. "Negative Binomial Distribution." From
+ MathWorld--A Wolfram Web Resource.
+ http://mathworld.wolfram.com/NegativeBinomialDistribution.html
+ .. [2] Wikipedia, "Negative binomial distribution",
+ https://en.wikipedia.org/wiki/Negative_binomial_distribution
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+  A real-world example: a company drills wild-cat oil
+  exploration wells, each with an estimated probability of
+  success of 0.1. What is the probability of having one success
+  for each successive well, that is, what is the probability of a
+  single success after drilling 5 wells, after 6 wells, and so on?
+
+ >>> s = bm.random.negative_binomial(1, 0.1, 100000)
+ >>> for i in range(1, 11): # doctest: +SKIP
+  ...    probability = sum(s < i) / 100000.
+  ...    print(i, "wells drilled, probability of one success =", probability)
+  """
+  return DEFAULT.negative_binomial(n, p, size, key=key)
+
+
+def noncentral_chisquare(df, nonc,
+                         size: Optional[Union[int, Sequence[int]]] = None,
+                         key: Optional[Union[int, JAX_RAND_KEY]] = None):
+  r"""
+  Draw samples from a noncentral chi-square distribution.
+
+  The noncentral :math:`\chi^2` distribution is a generalization of
+  the :math:`\chi^2` distribution.
+
+  Parameters
+  ----------
+  df : float or array_like of floats
+    Degrees of freedom, must be > 0.
+ nonc : float or array_like of floats
+ Non-centrality, must be non-negative.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``df`` and ``nonc`` are both scalars.
+ Otherwise, ``np.broadcast(df, nonc).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized noncentral chi-square distribution.
+
+ Notes
+ -----
+ The probability density function for the noncentral Chi-square
+ distribution is
+
+ .. math:: P(x;df,nonc) = \sum^{\infty}_{i=0}
+ \frac{e^{-nonc/2}(nonc/2)^{i}}{i!}
+ P_{Y_{df+2i}}(x),
+
+ where :math:`Y_{q}` is the Chi-square with q degrees of freedom.
+
+ References
+ ----------
+ .. [1] Wikipedia, "Noncentral chi-squared distribution"
+ https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution
+
+ Examples
+ --------
+ Draw values from the distribution and plot the histogram
+
+ >>> import matplotlib.pyplot as plt
+ >>> values = plt.hist(bm.random.noncentral_chisquare(3, 20, 100000),
+ ... bins=200, density=True)
+ >>> plt.show()
+
+ Draw values from a noncentral chisquare with very small noncentrality,
+ and compare to a chisquare.
+
+ >>> plt.figure()
+ >>> values = plt.hist(bm.random.noncentral_chisquare(3, .0000001, 100000),
+ ... bins=np.arange(0., 25, .1), density=True)
+ >>> values2 = plt.hist(bm.random.chisquare(3, 100000),
+ ... bins=np.arange(0., 25, .1), density=True)
+ >>> plt.plot(values[1][0:-1], values[0]-values2[0], 'ob')
+ >>> plt.show()
+
+ Demonstrate how large values of non-centrality lead to a more symmetric
+ distribution.
+
+ >>> plt.figure()
+ >>> values = plt.hist(bm.random.noncentral_chisquare(3, 20, 100000),
+ ... bins=200, density=True)
+ >>> plt.show()
+ """
+ return DEFAULT.noncentral_chisquare(df, nonc, size, key=key)
+
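A quick sanity-check sketch (assuming the usual ``bm`` alias): the noncentral chi-square has mean ``df + nonc``, which the sample mean should approximate.

import brainpy.math as bm

df, nonc = 3.0, 20.0
s = bm.random.noncentral_chisquare(df, nonc, size=100_000, key=1)
print(s.mean())   # close to df + nonc = 23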
+
+def noncentral_f(dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from the noncentral F distribution.
+
+  Samples are drawn from an F distribution with specified parameters,
+  `dfnum` (degrees of freedom in numerator) and `dfden` (degrees of
+  freedom in denominator), where both parameters must be greater than
+  zero. `nonc` is the non-centrality parameter.
+
+ Parameters
+ ----------
+ dfnum : float or array_like of floats
+ Numerator degrees of freedom, must be > 0.
+ dfden : float or array_like of floats
+ Denominator degrees of freedom, must be > 0.
+ nonc : float or array_like of floats
+ Non-centrality parameter, the sum of the squares of the numerator
+ means, must be >= 0.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``dfnum``, ``dfden``, and ``nonc``
+ are all scalars. Otherwise, ``np.broadcast(dfnum, dfden, nonc).size``
+ samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized noncentral Fisher distribution.
+
+ Notes
+ -----
+  When calculating the power of an experiment (power = probability of
+  rejecting the null hypothesis when a specific alternative is true), the
+  noncentral F statistic becomes important. When the null hypothesis is
+  true, the F statistic follows a central F distribution. When the null
+  hypothesis is not true, it follows a noncentral F distribution.
+
+ References
+ ----------
+ .. [1] Weisstein, Eric W. "Noncentral F-Distribution."
+ From MathWorld--A Wolfram Web Resource.
+ http://mathworld.wolfram.com/NoncentralF-Distribution.html
+ .. [2] Wikipedia, "Noncentral F-distribution",
+ https://en.wikipedia.org/wiki/Noncentral_F-distribution
+
+ Examples
+ --------
+ In a study, testing for a specific alternative to the null hypothesis
+ requires use of the Noncentral F distribution. We need to calculate the
+ area in the tail of the distribution that exceeds the value of the F
+ distribution for the null hypothesis. We'll plot the two probability
+ distributions for comparison.
+
+ >>> dfnum = 3 # between group deg of freedom
+ >>> dfden = 20 # within groups degrees of freedom
+ >>> nonc = 3.0
+ >>> nc_vals = bm.random.noncentral_f(dfnum, dfden, nonc, 1000000)
+ >>> NF = np.histogram(nc_vals, bins=50, density=True)
+ >>> c_vals = bm.random.f(dfnum, dfden, 1000000)
+ >>> F = np.histogram(c_vals, bins=50, density=True)
+ >>> import matplotlib.pyplot as plt
+ >>> plt.plot(F[1][1:], F[0])
+ >>> plt.plot(NF[1][1:], NF[0])
+ >>> plt.show()
+ """
+ return DEFAULT.noncentral_f(dfnum, dfden, nonc, size, key=key)
+
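A sketch of the power calculation described above, assuming the usual ``bm`` alias and that ``bm.random.f`` accepts the same ``key`` keyword as the other samplers here: the empirical power is the fraction of noncentral draws exceeding the central distribution's 95th percentile.

import numpy as np
import brainpy.math as bm

dfnum, dfden, nonc = 3, 20, 3.0
nc_vals = np.asarray(bm.random.noncentral_f(dfnum, dfden, nonc, size=100_000, key=2))
c_vals = np.asarray(bm.random.f(dfnum, dfden, size=100_000, key=3))
crit = np.percentile(c_vals, 95)   # critical value under the null hypothesis
print((nc_vals > crit).mean())     # estimated power of the test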
+
+def power(a,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draws samples in [0, 1] from a power distribution with positive
+ exponent a - 1.
+
+ Also known as the power function distribution.
+
+ Parameters
+ ----------
+ a : float or array_like of floats
+    Parameter of the distribution. Must be positive.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``a`` is a scalar. Otherwise,
+ ``np.array(a).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized power distribution.
+
+ Raises
+ ------
+ ValueError
+ If a <= 0.
+
+ Notes
+ -----
+ The probability density function is
+
+ .. math:: P(x; a) = ax^{a-1}, 0 \le x \le 1, a>0.
+
+ The power function distribution is just the inverse of the Pareto
+ distribution. It may also be seen as a special case of the Beta
+ distribution.
+
+ It is used, for example, in modeling the over-reporting of insurance
+ claims.
+
+ References
+ ----------
+ .. [1] Christian Kleiber, Samuel Kotz, "Statistical size distributions
+ in economics and actuarial sciences", Wiley, 2003.
+ .. [2] Heckert, N. A. and Filliben, James J. "NIST Handbook 148:
+ Dataplot Reference Manual, Volume 2: Let Subcommands and Library
+ Functions", National Institute of Standards and Technology
+ Handbook Series, June 2003.
+ https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/powpdf.pdf
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> a = 5. # shape
+ >>> samples = 1000
+ >>> s = bm.random.power(a, samples)
+
+ Display the histogram of the samples, along with
+ the probability density function:
+
+ >>> import matplotlib.pyplot as plt
+ >>> count, bins, ignored = plt.hist(s, bins=30)
+ >>> x = np.linspace(0, 1, 100)
+ >>> y = a*x**(a-1.)
+ >>> normed_y = samples*np.diff(bins)[0]*y
+ >>> plt.plot(x, normed_y)
+ >>> plt.show()
+
+ Compare the power function distribution to the inverse of the Pareto.
+
+ >>> from scipy import stats # doctest: +SKIP
+ >>> rvs = bm.random.power(5, 1000000)
+ >>> rvsp = bm.random.pareto(5, 1000000)
+ >>> xx = np.linspace(0,1,100)
+ >>> powpdf = stats.powerlaw.pdf(xx,5) # doctest: +SKIP
+
+ >>> plt.figure()
+ >>> plt.hist(rvs, bins=50, density=True)
+ >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP
+ >>> plt.title('bm.random.power(5)')
+
+ >>> plt.figure()
+ >>> plt.hist(1./(1.+rvsp), bins=50, density=True)
+ >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP
+ >>> plt.title('inverse of 1 + bm.random.pareto(5)')
+
+ >>> plt.figure()
+ >>> plt.hist(1./(1.+rvsp), bins=50, density=True)
+ >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP
+ >>> plt.title('inverse of stats.pareto(5)')
+ """
+ return DEFAULT.power(a, size, key=key)
+
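A brief check (assuming the usual ``bm`` alias): since the pdf is ``a * x**(a-1)`` on [0, 1], the mean is ``a / (a + 1)``.

import brainpy.math as bm

a = 5.0
s = bm.random.power(a, size=100_000, key=4)
print(s.mean())   # close to a / (a + 1) = 5/6 ≈ 0.833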
+
+def rayleigh(scale=1.0,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a Rayleigh distribution.
+
+ The :math:`\chi` and Weibull distributions are generalizations of the
+ Rayleigh.
+
+ Parameters
+ ----------
+ scale : float or array_like of floats, optional
+ Scale, also equals the mode. Must be non-negative. Default is 1.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``scale`` is a scalar. Otherwise,
+ ``np.array(scale).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized Rayleigh distribution.
+
+ Notes
+ -----
+ The probability density function for the Rayleigh distribution is
+
+ .. math:: P(x;scale) = \frac{x}{scale^2}e^{\frac{-x^2}{2 \cdotp scale^2}}
+
+ The Rayleigh distribution would arise, for example, if the East
+ and North components of the wind velocity had identical zero-mean
+ Gaussian distributions. Then the wind speed would have a Rayleigh
+ distribution.
+
+ References
+ ----------
+ .. [1] Brighton Webs Ltd., "Rayleigh Distribution,"
+ https://web.archive.org/web/20090514091424/http://brighton-webs.co.uk:80/distributions/rayleigh.asp
+ .. [2] Wikipedia, "Rayleigh distribution"
+ https://en.wikipedia.org/wiki/Rayleigh_distribution
+
+ Examples
+ --------
+ Draw values from the distribution and plot the histogram
+
+ >>> from matplotlib.pyplot import hist
+ >>> values = hist(bm.random.rayleigh(3, 100000), bins=200, density=True)
+
+ Wave heights tend to follow a Rayleigh distribution. If the mean wave
+ height is 1 meter, what fraction of waves are likely to be larger than 3
+ meters?
+
+ >>> meanvalue = 1
+ >>> modevalue = np.sqrt(2 / np.pi) * meanvalue
+ >>> s = bm.random.rayleigh(modevalue, 1000000)
+
+ The percentage of waves larger than 3 meters is:
+
+ >>> 100.*sum(s>3)/1000000.
+ 0.087300000000000003 # random
+ """
+ return DEFAULT.rayleigh(scale, size, key=key)
+
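The wave-height question above can also be checked against the closed-form survival function ``exp(-x**2 / (2 * scale**2))``; a minimal sketch assuming the usual ``bm`` alias:

import numpy as np
import brainpy.math as bm

scale = np.sqrt(2 / np.pi)            # scale (mode) giving a mean wave height of 1 m
s = np.asarray(bm.random.rayleigh(scale, size=1_000_000, key=5))
print((s > 3).mean())                 # empirical fraction of waves above 3 m
print(np.exp(-9 / (2 * scale ** 2)))  # analytic value, about 8.5e-4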
+
+def triangular(size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from the triangular distribution over the
+ interval ``[left, right]``.
+
+ The triangular distribution is a continuous probability
+ distribution with lower limit left, peak at mode, and upper
+ limit right. Unlike the other distributions, these parameters
+ directly define the shape of the pdf.
+
+ Parameters
+ ----------
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``left``, ``mode``, and ``right``
+ are all scalars. Otherwise, ``np.broadcast(left, mode, right).size``
+ samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized triangular distribution.
+
+ Notes
+ -----
+ The probability density function for the triangular distribution is
+
+ .. math:: P(x;l, m, r) = \begin{cases}
+ \frac{2(x-l)}{(r-l)(m-l)}& \text{for $l \leq x \leq m$},\\
+ \frac{2(r-x)}{(r-l)(r-m)}& \text{for $m \leq x \leq r$},\\
+ 0& \text{otherwise}.
+ \end{cases}
+
+ The triangular distribution is often used in ill-defined
+ problems where the underlying distribution is not known, but
+ some knowledge of the limits and mode exists. Often it is used
+ in simulations.
+
+ References
+ ----------
+ .. [1] Wikipedia, "Triangular distribution"
+ https://en.wikipedia.org/wiki/Triangular_distribution
+
+ Examples
+ --------
+ Draw values from the distribution and plot the histogram:
+
+ >>> import matplotlib.pyplot as plt
+ >>> h = plt.hist(bm.random.triangular(-3, 0, 8, 100000), bins=200,
+ ... density=True)
+ >>> plt.show()
+ """
+ return DEFAULT.triangular(size, key=key)
+
+
+def vonmises(mu,
+ kappa,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a von Mises distribution.
+
+ Samples are drawn from a von Mises distribution with specified mode
+ (mu) and dispersion (kappa), on the interval [-pi, pi].
+
+ The von Mises distribution (also known as the circular normal
+ distribution) is a continuous probability distribution on the unit
+ circle. It may be thought of as the circular analogue of the normal
+ distribution.
+
+ Parameters
+ ----------
+ mu : float or array_like of floats
+ Mode ("center") of the distribution.
+ kappa : float or array_like of floats
+ Dispersion of the distribution, has to be >=0.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``mu`` and ``kappa`` are both scalars.
+ Otherwise, ``np.broadcast(mu, kappa).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized von Mises distribution.
+
+ See Also
+ --------
+ scipy.stats.vonmises : probability density function, distribution, or
+ cumulative density function, etc.
+
+ Notes
+ -----
+ The probability density for the von Mises distribution is
+
+  .. math:: p(x) = \frac{e^{\kappa \cos(x-\mu)}}{2\pi I_0(\kappa)},
+
+ where :math:`\mu` is the mode and :math:`\kappa` the dispersion,
+ and :math:`I_0(\kappa)` is the modified Bessel function of order 0.
+
+ The von Mises is named for Richard Edler von Mises, who was born in
+ Austria-Hungary, in what is now the Ukraine. He fled to the United
+ States in 1939 and became a professor at Harvard. He worked in
+ probability theory, aerodynamics, fluid mechanics, and philosophy of
+ science.
+
+ References
+ ----------
+ .. [1] Abramowitz, M. and Stegun, I. A. (Eds.). "Handbook of
+ Mathematical Functions with Formulas, Graphs, and Mathematical
+ Tables, 9th printing," New York: Dover, 1972.
+ .. [2] von Mises, R., "Mathematical Theory of Probability
+ and Statistics", New York: Academic Press, 1964.
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> mu, kappa = 0.0, 4.0 # mean and dispersion
+ >>> s = bm.random.vonmises(mu, kappa, 1000)
+
+ Display the histogram of the samples, along with
+ the probability density function:
+
+ >>> import matplotlib.pyplot as plt
+ >>> from scipy.special import i0 # doctest: +SKIP
+ >>> plt.hist(s, 50, density=True)
+ >>> x = np.linspace(-np.pi, np.pi, num=51)
+ >>> y = np.exp(kappa*np.cos(x-mu))/(2*np.pi*i0(kappa)) # doctest: +SKIP
+ >>> plt.plot(x, y, linewidth=2, color='r') # doctest: +SKIP
+ >>> plt.show()
+ """
+ return DEFAULT.vonmises(mu, kappa, size, key=key)
+
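A small sketch (assuming the usual ``bm`` alias) showing that the circular mean of the samples recovers ``mu``:

import numpy as np
import brainpy.math as bm

mu, kappa = 0.5, 4.0
s = np.asarray(bm.random.vonmises(mu, kappa, size=50_000, key=6))
# Average the angles on the unit circle rather than directly on [-pi, pi].
print(np.arctan2(np.sin(s).mean(), np.cos(s).mean()))   # close to mu = 0.5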
+
+def wald(mean,
+ scale,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a Wald, or inverse Gaussian, distribution.
+
+ As the scale approaches infinity, the distribution becomes more like a
+ Gaussian. Some references claim that the Wald is an inverse Gaussian
+ with mean equal to 1, but this is by no means universal.
+
+ The inverse Gaussian distribution was first studied in relationship to
+ Brownian motion. In 1956 M.C.K. Tweedie used the name inverse Gaussian
+ because there is an inverse relationship between the time to cover a
+ unit distance and distance covered in unit time.
+
+ Parameters
+ ----------
+ mean : float or array_like of floats
+ Distribution mean, must be > 0.
+ scale : float or array_like of floats
+ Scale parameter, must be > 0.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``mean`` and ``scale`` are both scalars.
+ Otherwise, ``np.broadcast(mean, scale).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized Wald distribution.
+
+ Notes
+ -----
+ The probability density function for the Wald distribution is
+
+ .. math:: P(x;mean,scale) = \sqrt{\frac{scale}{2\pi x^3}}e^
+ \frac{-scale(x-mean)^2}{2\cdotp mean^2x}
+
+  As noted above, the inverse Gaussian distribution first arose
+  from attempts to model Brownian motion. It is also a
+  competitor to the Weibull for use in reliability modeling and in
+  modeling stock returns and interest rate processes.
+
+ References
+ ----------
+ .. [1] Brighton Webs Ltd., Wald Distribution,
+ https://web.archive.org/web/20090423014010/http://www.brighton-webs.co.uk:80/distributions/wald.asp
+ .. [2] Chhikara, Raj S., and Folks, J. Leroy, "The Inverse Gaussian
+     Distribution: Theory, Methodology, and Applications", CRC Press,
+ 1988.
+ .. [3] Wikipedia, "Inverse Gaussian distribution"
+ https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution
+
+ Examples
+ --------
+ Draw values from the distribution and plot the histogram:
+
+ >>> import matplotlib.pyplot as plt
+ >>> h = plt.hist(bm.random.wald(3, 2, 100000), bins=200, density=True)
+ >>> plt.show()
+ """
+ return DEFAULT.wald(mean, scale, size, key=key)
+
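A quick check (assuming the usual ``bm`` alias): the inverse Gaussian has ``E[X] = mean`` and ``Var[X] = mean**3 / scale``.

import brainpy.math as bm

mean, scale = 3.0, 2.0
s = bm.random.wald(mean, scale, size=200_000, key=7)
print(s.mean(), s.var())   # roughly 3.0 and 27/2 = 13.5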
+
+def weibull(a,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ r"""
+ Draw samples from a Weibull distribution.
+
+ Draw samples from a 1-parameter Weibull distribution with the given
+ shape parameter `a`.
+
+  .. math:: X = (-\ln(U))^{1/a}
+
+  Here, U is drawn from the uniform distribution over (0,1].
+
+  The more common 2-parameter Weibull, including a scale parameter
+  :math:`\lambda`, is just :math:`X = \lambda(-\ln(U))^{1/a}`.
+
+ .. note::
+ New code should use the ``weibull`` method of a ``default_rng()``
+ instance instead; please see the :ref:`random-quick-start`.
+
+ Parameters
+ ----------
+ a : float or array_like of floats
+ Shape parameter of the distribution. Must be nonnegative.
+ size : int or tuple of ints, optional
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. If size is ``None`` (default),
+ a single value is returned if ``a`` is a scalar. Otherwise,
+ ``np.array(a).size`` samples are drawn.
+
+ Returns
+ -------
+ out : ndarray or scalar
+ Drawn samples from the parameterized Weibull distribution.
+
+ Notes
+ -----
+ The Weibull (or Type III asymptotic extreme value distribution
+ for smallest values, SEV Type III, or Rosin-Rammler
+ distribution) is one of a class of Generalized Extreme Value
+ (GEV) distributions used in modeling extreme value problems.
+ This class includes the Gumbel and Frechet distributions.
+
+ The probability density for the Weibull distribution is
+
+ .. math:: p(x) = \frac{a}
+ {\lambda}(\frac{x}{\lambda})^{a-1}e^{-(x/\lambda)^a},
+
+ where :math:`a` is the shape and :math:`\lambda` the scale.
+
+ The function has its peak (the mode) at
+ :math:`\lambda(\frac{a-1}{a})^{1/a}`.
+
+ When ``a = 1``, the Weibull distribution reduces to the exponential
+ distribution.
+
+ References
+ ----------
+ .. [1] Waloddi Weibull, Royal Technical University, Stockholm,
+ 1939 "A Statistical Theory Of The Strength Of Materials",
+ Ingeniorsvetenskapsakademiens Handlingar Nr 151, 1939,
+ Generalstabens Litografiska Anstalts Forlag, Stockholm.
+ .. [2] Waloddi Weibull, "A Statistical Distribution Function of
+ Wide Applicability", Journal Of Applied Mechanics ASME Paper
+ 1951.
+ .. [3] Wikipedia, "Weibull distribution",
+ https://en.wikipedia.org/wiki/Weibull_distribution
+
+ Examples
+ --------
+ Draw samples from the distribution:
+
+ >>> a = 5. # shape
+ >>> s = brainpy.math.random.weibull(a, 1000)
+
+ Display the histogram of the samples, along with
+ the probability density function:
+
+ >>> import matplotlib.pyplot as plt
+ >>> x = np.arange(1,100.)/50.
+ >>> def weib(x,n,a):
+ ... return (a / n) * (x / n)**(a - 1) * np.exp(-(x / n)**a)
>>> count, bins, ignored = plt.hist(brainpy.math.random.weibull(5.,1000))
>>> x = np.arange(1,100.)/50.
@@ -2152,7 +4515,10 @@ def weibull(a, size=None, key=None):
return DEFAULT.weibull(a, size, key=key)
-def weibull_min(a, scale=None, size=None, key=None):
+def weibull_min(a,
+ scale=None,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Sample from a Weibull distribution.
The scipy counterpart is `scipy.stats.weibull_min`.
@@ -2171,7 +4537,9 @@ def weibull_min(a, scale=None, size=None, key=None):
return DEFAULT.weibull_min(a, scale, size, key=key)
-def zipf(a, size=None, key=None):
+def zipf(a,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
r"""
Draw samples from a Zipf distribution.
@@ -2206,7 +4574,6 @@ def zipf(a, size=None, key=None):
--------
scipy.stats.zipf : probability density function, distribution, or
cumulative density function, etc.
- random.Generator.zipf: which should be used for new code.
Notes
-----
@@ -2259,7 +4626,8 @@ def zipf(a, size=None, key=None):
return DEFAULT.zipf(a, size, key=key)
-def maxwell(size=None, key=None):
+def maxwell(size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Sample from a one sided Maxwell distribution.
The scipy counterpart is `scipy.stats.maxwell`.
@@ -2276,7 +4644,9 @@ def maxwell(size=None, key=None):
return DEFAULT.maxwell(size, key=key)
-def t(df, size=None, key=None):
+def t(df,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Sample Student’s t random values.
Parameters
@@ -2295,7 +4665,9 @@ def t(df, size=None, key=None):
return DEFAULT.t(df, size, key=key)
-def orthogonal(n: int, size=None, key=None):
+def orthogonal(n: int,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Sample uniformly from the orthogonal group `O(n)`.
Parameters
@@ -2313,7 +4685,9 @@ def orthogonal(n: int, size=None, key=None):
return DEFAULT.orthogonal(n, size, key=key)
-def loggamma(a, size=None, key=None):
+def loggamma(a,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Sample log-gamma random values.
Parameters
@@ -2329,10 +4703,13 @@ def loggamma(a, size=None, key=None):
out: array_like
The sampled results.
"""
- return DEFAULT.loggamma(a, size)
+ return DEFAULT.loggamma(a, size, key=key)
-def categorical(logits, axis: int = -1, size=None, key=None):
+def categorical(logits,
+ axis: int = -1,
+ size: Optional[Union[int, Sequence[int]]] = None,
+ key: Optional[Union[int, JAX_RAND_KEY]] = None):
"""Sample random values from categorical distributions.
Args:
@@ -2351,9 +4728,9 @@ def categorical(logits, axis: int = -1, size=None, key=None):
return DEFAULT.categorical(logits, axis, size, key=key)
-def rand_like(input, *, dtype=None, key=None):
- """Similar to ``rand_like`` in torch.
-
+def rand_like(input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ """Similar to ``rand_like`` in torch.
+
Returns a tensor with the same size as input that is filled with random
numbers from a uniform distribution on the interval ``[0, 1)``.
@@ -2368,9 +4745,9 @@ def rand_like(input, *, dtype=None, key=None):
return DEFAULT.rand_like(input, dtype=dtype, key=key)
-def randn_like(input, *, dtype=None, key=None):
- """Similar to ``randn_like`` in torch.
-
+def randn_like(input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ """Similar to ``randn_like`` in torch.
+
Returns a tensor with the same size as ``input`` that is filled with
random numbers from a normal distribution with mean 0 and variance 1.
@@ -2385,9 +4762,9 @@ def randn_like(input, *, dtype=None, key=None):
return DEFAULT.randn_like(input, dtype=dtype, key=key)
-def randint_like(input, low=0, high=None, *, dtype=None, key=None):
- """Similar to ``randint_like`` in torch.
-
+def randint_like(input, low=0, high=None, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
+ """Similar to ``randint_like`` in torch.
+
Returns a tensor with the same shape as Tensor ``input`` filled with
random integers generated uniformly between ``low`` (inclusive) and ``high`` (exclusive).
diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py
index 0acd2010b..453ab387d 100644
--- a/brainpy/_src/math/sparse/_bsr_mm.py
+++ b/brainpy/_src/math/sparse/_bsr_mm.py
@@ -404,8 +404,8 @@ def _bcsrmm_cutlass_jvp_transpose():
_bcsrmm_cutlass_p.multiple_results = True
_bcsrmm_cutlass_p.def_abstract_eval(_bcsrmm_cutlass_abstract)
_bcsrmm_cutlass_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_p))
-xla.backend_specific_translations['cpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_cpu_translation
-xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_gpu_translation
+# xla.backend_specific_translations['cpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_cpu_translation
+# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_gpu_translation
ad.primitive_jvps[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose
ad.primitive_transposes[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose
register_general_batching(bcsrmm)
@@ -456,5 +456,5 @@ def _blocksparse_matmat_back_gpu_translation(
_bcsrmm_cutlass_back_p.multiple_results = True
_bcsrmm_cutlass_back_p.def_abstract_eval(_blocksparse_matmat_back_abstract)
_bcsrmm_cutlass_back_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_back_p))
-xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_back_p] = _blocksparse_matmat_back_gpu_translation
+# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_back_p] = _blocksparse_matmat_back_gpu_translation
register_general_batching(_bcsrmm_cutlass_back_p)
diff --git a/brainpy/_src/math/sparse/_bsr_mv.py b/brainpy/_src/math/sparse/_bsr_mv.py
index 76d1715e0..a35895bc1 100644
--- a/brainpy/_src/math/sparse/_bsr_mv.py
+++ b/brainpy/_src/math/sparse/_bsr_mv.py
@@ -202,8 +202,8 @@ def _cusparse_bcsr_transpose(ct, data, indices, indptr, vector, *, blocksize, sh
cusparse_bcsr_matvec_vector_p = Primitive('cusparse_block_spmv')
cusparse_bcsr_matvec_vector_p.def_abstract_eval(_cusparse_bcsr_matvec_abstract)
cusparse_bcsr_matvec_vector_p.def_impl(partial(xla.apply_primitive, cusparse_bcsr_matvec_vector_p))
-xla.backend_specific_translations['gpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_gpu_translation
-xla.backend_specific_translations['cpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_cpu_translation
+# xla.backend_specific_translations['gpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_gpu_translation
+# xla.backend_specific_translations['cpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_cpu_translation
ad.defjvp(cusparse_bcsr_matvec_vector_p, _cusparse_bcsr_matvec_jvp_values)
ad.primitive_transposes[cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_transpose
register_general_batching(cusparse_bcsr_matvec_vector_p)
diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py
index e29dbfb9b..377597579 100644
--- a/brainpy/_src/math/sparse/_csr_mv.py
+++ b/brainpy/_src/math/sparse/_csr_mv.py
@@ -13,234 +13,295 @@
from jax.lib import xla_client
from jaxlib import gpu_sparse
+from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (compile_cpu_signature_with_numba,
- register_general_batching)
+ register_general_batching,
+ XLACustomOp)
from brainpy._src.math.sparse._utils import csr_to_coo
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops
from brainpy.errors import GPUOperatorNotFound
+ti = import_taichi()
+
__all__ = [
- 'csrmv',
+ 'csrmv',
]
def csrmv(
- data: Union[float, jnp.ndarray, Array],
- indices: Union[jnp.ndarray, Array],
- indptr: Union[jnp.ndarray, Array],
- vector: Union[jnp.ndarray, Array],
- *,
- shape: Tuple[int, int],
- transpose: bool = False,
- method: str = 'cusparse',
+ data: Union[float, jnp.ndarray, Array],
+ indices: Union[jnp.ndarray, Array],
+ indptr: Union[jnp.ndarray, Array],
+ vector: Union[jnp.ndarray, Array],
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ method: str = None,
):
- """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm.
-
- This function supports JAX transformations, including `jit()`, `grad()`,
- `vmap()` and `pmap()`.
-
- Parameters
- ----------
- data: ndarray, float
- An array of shape ``(nse,)``.
- indices: ndarray
- An array of shape ``(nse,)``.
- indptr: ndarray
- An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
- vector: ndarray
- An array of shape ``(shape[0] if transpose else shape[1],)``
- and dtype ``data.dtype``.
- shape: tuple of int
- A length-2 tuple representing the matrix shape.
- transpose: bool
- A boolean specifying whether to transpose the sparse matrix
- before computing.
- method: str
- The method used to compute Matrix-Vector Multiplication. The candidate methods are:
-
- - ``cusparse``: using cuSPARSE library.
- - ``scalar``:
- - ``vector``:
- - ``adaptive``:
-
- Returns
- -------
- y : ndarry
- The array of shape ``(shape[1] if transpose else shape[0],)`` representing
- the matrix vector product.
- """
-
- data = jnp.atleast_1d(as_jax(data))
- indices = as_jax(indices)
- indptr = as_jax(indptr)
- vector = as_jax(vector)
-
- if vector.dtype == jnp.bool_:
- vector = as_jax(vector, dtype=data.dtype)
-
- if method == 'cusparse':
- if jax.default_backend() == 'gpu':
- if data.shape[0] == 1:
- data = jnp.ones(indices.shape, dtype=data.dtype) * data
- if indices.dtype in [jnp.uint32, jnp.uint64]:
- indices = jnp.asarray(indices, dtype=dtypes.canonicalize_dtype(jnp.int64))
- if indptr.dtype in [jnp.uint32, jnp.uint64]:
- indptr = jnp.asarray(indptr, dtype=dtypes.canonicalize_dtype(jnp.int64))
- return _csrmv_cusparse_p.bind(data,
- indices,
- indptr,
- vector,
- shape=shape,
- transpose=transpose)
-
- elif method == 'adaptive':
- return _csrmv_adaptive_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)
-
- elif method == 'scalar':
- return _csrmv_scalar_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)
-
- elif method == 'vector':
- return _csrmv_vector_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)
-
- else:
- raise ValueError(f'Only support methods: cusparse, scalar, vector, and adaptive. But we got {method}.')
+  """Product of a CSR sparse matrix and a dense vector.
+
+ This function supports JAX transformations, including `jit()`, `grad()`,
+ `vmap()` and `pmap()`.
+
+ Parameters
+ ----------
+ data: ndarray, float
+ An array of shape ``(nse,)``.
+ indices: ndarray
+ An array of shape ``(nse,)``.
+ indptr: ndarray
+ An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+ vector: ndarray
+ An array of shape ``(shape[0] if transpose else shape[1],)``
+ and dtype ``data.dtype``.
+ shape: tuple of int
+ A length-2 tuple representing the matrix shape.
+ transpose: bool
+ A boolean specifying whether to transpose the sparse matrix
+ before computing.
+ method: str
+    The method used to compute the matrix-vector multiplication. Default is ``None``, which uses the Taichi kernel.
+ The candidate methods are:
+
+ - ``None``: default using Taichi kernel.
+ - ``cusparse``: using cuSPARSE library.
+ - ``scalar``:
+ - ``vector``:
+ - ``adaptive``:
+
+ Returns
+ -------
+  y : ndarray
+ The array of shape ``(shape[1] if transpose else shape[0],)`` representing
+ the matrix vector product.
+ """
+ if method is None:
+ return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)
+ else:
+ return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method)
+
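A small end-to-end sketch of the dispatcher above, assuming the public alias ``brainpy.math.sparse.csrmv`` re-exports this function and the Taichi-backed default path (``method=None``) is available:

import jax.numpy as jnp
import brainpy.math as bm

# CSR storage of the 3x4 matrix
# [[1, 0, 2, 0],
#  [0, 0, 3, 0],
#  [4, 5, 0, 6]]
data = jnp.array([1., 2., 3., 4., 5., 6.])
indices = jnp.array([0, 2, 2, 0, 1, 3])
indptr = jnp.array([0, 2, 3, 6])
vector = jnp.array([1., 2., 3., 4.])

y = bm.sparse.csrmv(data, indices, indptr, vector, shape=(3, 4))
rows = jnp.array([0, 0, 1, 2, 2, 2])
dense = jnp.zeros((3, 4)).at[rows, indices].set(data)
print(y)                # expected: [ 7.  9. 38.]
print(dense @ vector)   # same result from the dense reference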
+
+### BRAINPYLIB ###
+
+def csrmv_brainpylib(
+ data: Union[float, jnp.ndarray, Array],
+ indices: Union[jnp.ndarray, Array],
+ indptr: Union[jnp.ndarray, Array],
+ vector: Union[jnp.ndarray, Array],
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+ method: str = 'cusparse',
+):
+ """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm.
+
+ This function supports JAX transformations, including `jit()`, `grad()`,
+ `vmap()` and `pmap()`.
+
+ Parameters
+ ----------
+ data: ndarray, float
+ An array of shape ``(nse,)``.
+ indices: ndarray
+ An array of shape ``(nse,)``.
+ indptr: ndarray
+ An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+ vector: ndarray
+ An array of shape ``(shape[0] if transpose else shape[1],)``
+ and dtype ``data.dtype``.
+ shape: tuple of int
+ A length-2 tuple representing the matrix shape.
+ transpose: bool
+ A boolean specifying whether to transpose the sparse matrix
+ before computing.
+ method: str
+ The method used to compute Matrix-Vector Multiplication. The candidate methods are:
+
+ - ``cusparse``: using cuSPARSE library.
+ - ``scalar``:
+ - ``vector``:
+ - ``adaptive``:
+
+ Returns
+ -------
+  y : ndarray
+ The array of shape ``(shape[1] if transpose else shape[0],)`` representing
+ the matrix vector product.
+ """
+
+ data = jnp.atleast_1d(as_jax(data))
+ indices = as_jax(indices)
+ indptr = as_jax(indptr)
+ vector = as_jax(vector)
+
+ if vector.dtype == jnp.bool_:
+ vector = as_jax(vector, dtype=data.dtype)
+
+ if method == 'cusparse':
+ if jax.default_backend() == 'gpu':
+ if data.shape[0] == 1:
+ data = jnp.ones(indices.shape, dtype=data.dtype) * data
+ if indices.dtype in [jnp.uint32, jnp.uint64]:
+ indices = jnp.asarray(indices, dtype=dtypes.canonicalize_dtype(jnp.int64))
+ if indptr.dtype in [jnp.uint32, jnp.uint64]:
+ indptr = jnp.asarray(indptr, dtype=dtypes.canonicalize_dtype(jnp.int64))
+ return _csrmv_cusparse_p.bind(data,
+ indices,
+ indptr,
+ vector,
+ shape=shape,
+ transpose=transpose)
+
+ elif method == 'adaptive':
+ return _csrmv_adaptive_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)
+
+ elif method == 'scalar':
+ return _csrmv_scalar_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)
+
+ elif method == 'vector':
+ return _csrmv_vector_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)
+
+ else:
+ raise ValueError(f'Only support methods: cusparse, scalar, vector, and adaptive. But we got {method}.')
def _csrmv_abstract(data, indices, indptr, vector, *, shape, transpose):
- if data.dtype not in [jnp.float32, jnp.float64]:
- raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.')
- if data.dtype != vector.dtype:
- raise TypeError('The types of data and vector should be the same. '
- f'But we got {data.dtype} != {vector.dtype}.')
- assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1
- if not jnp.issubdtype(indices.dtype, jnp.integer):
- raise ValueError('indices should be a 1D vector with integer type.')
- if not jnp.issubdtype(indptr.dtype, jnp.integer):
- raise ValueError('indptr should be a 1D vector with integer type.')
- out_shape = shape[1] if transpose else shape[0]
- return core.ShapedArray((out_shape,), data.dtype)
+ if data.dtype not in [jnp.float32, jnp.float64]:
+ raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.')
+ if data.dtype != vector.dtype:
+ raise TypeError('The types of data and vector should be the same. '
+ f'But we got {data.dtype} != {vector.dtype}.')
+ assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1
+ if not jnp.issubdtype(indices.dtype, jnp.integer):
+ raise ValueError('indices should be a 1D vector with integer type.')
+ if not jnp.issubdtype(indptr.dtype, jnp.integer):
+ raise ValueError('indptr should be a 1D vector with integer type.')
+ out_shape = shape[1] if transpose else shape[0]
+ return core.ShapedArray((out_shape,), data.dtype)
@numba.njit(fastmath=True)
def _csr_matvec_transpose_numba_imp(outs, ins):
- res_val = outs
- res_val.fill(0)
- values, col_indices, row_ptr, vector, shape, _ = ins
- # (csr mat).T @ vec
-
- if values.shape[0] == 1:
- values = values[0]
- for row_i in range(shape[0]):
- v = vector[row_i]
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- res_val[col_indices[j]] += values * v
- else:
- for row_i in range(shape[0]):
- v = vector[row_i]
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- res_val[col_indices[j]] += v * values[j]
+ res_val = outs
+ res_val.fill(0)
+ values, col_indices, row_ptr, vector, shape, _ = ins
+ # (csr mat).T @ vec
+
+ if values.shape[0] == 1:
+ values = values[0]
+ for row_i in range(shape[0]):
+ v = vector[row_i]
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ res_val[col_indices[j]] += values * v
+ else:
+ for row_i in range(shape[0]):
+ v = vector[row_i]
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ res_val[col_indices[j]] += v * values[j]
@numba.njit(fastmath=True, parallel=True, nogil=True)
def _csr_matvec_numba_imp(outs, ins):
- res_val = outs
- res_val.fill(0)
- values, col_indices, row_ptr, vector, shape, _ = ins
- # csr mat @ vec
- if values.shape[0] == 1:
- values = values[0]
- for row_i in numba.prange(shape[0]):
- r = 0.
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- r += values * vector[col_indices[j]]
- res_val[row_i] = r
- else:
- for row_i in numba.prange(shape[0]):
- r = 0.
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- r += values[j] * vector[col_indices[j]]
- res_val[row_i] = r
+ res_val = outs
+ res_val.fill(0)
+ values, col_indices, row_ptr, vector, shape, _ = ins
+ # csr mat @ vec
+ if values.shape[0] == 1:
+ values = values[0]
+ for row_i in numba.prange(shape[0]):
+ r = 0.
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ r += values * vector[col_indices[j]]
+ res_val[row_i] = r
+ else:
+ for row_i in numba.prange(shape[0]):
+ r = 0.
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ r += values[j] * vector[col_indices[j]]
+ res_val[row_i] = r
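For readers tracing the kernels above, a plain-NumPy reference of the non-transposed loop (illustrative only, not part of the module):

import numpy as np

def csr_matvec_reference(values, col_indices, row_ptr, vector, n_rows):
  # Mirrors _csr_matvec_numba_imp: accumulate each row over its nonzeros,
  # broadcasting a single shared weight when `values` has length 1.
  out = np.zeros(n_rows, dtype=vector.dtype)
  homo = values.shape[0] == 1
  for row_i in range(n_rows):
    r = 0.
    for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
      w = values[0] if homo else values[j]
      r += w * vector[col_indices[j]]
    out[row_i] = r
  return out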
def _csrmv_cpu_translation(c, data, indices, indptr, vector, *, shape, transpose):
- inputs = (data, indices, indptr, vector)
- description = dict(shape=shape, transpose=transpose)
- if transpose:
- target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba(
- c,
- _csr_matvec_transpose_numba_imp,
- _csrmv_abstract,
- multiple_results=False,
- inputs=inputs,
- description=description
- )
- else:
- target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba(
- c,
- _csr_matvec_numba_imp,
- _csrmv_abstract,
- multiple_results=False,
- inputs=inputs,
- description=description
- )
- return xla_client.ops.CustomCallWithLayout(
- c,
- target_name,
- operands=inputs,
- operand_shapes_with_layout=input_layouts,
- shape_with_layout=output_layouts,
+ inputs = (data, indices, indptr, vector)
+ description = dict(shape=shape, transpose=transpose)
+ if transpose:
+ target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba(
+ c,
+ _csr_matvec_transpose_numba_imp,
+ _csrmv_abstract,
+ multiple_results=False,
+ inputs=inputs,
+ description=description
)
+ else:
+ target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba(
+ c,
+ _csr_matvec_numba_imp,
+ _csrmv_abstract,
+ multiple_results=False,
+ inputs=inputs,
+ description=description
+ )
+ return xla_client.ops.CustomCallWithLayout(
+ c,
+ target_name,
+ operands=inputs,
+ operand_shapes_with_layout=input_layouts,
+ shape_with_layout=output_layouts,
+ )
def _csrmv_cusparse_gpu_lowering(ctx, data, indices, indptr, vector, *, shape, transpose):
- data_aval, indices_aval, _, v_aval = ctx.avals_in
- dtype = data_aval.dtype
- if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
- raise TypeError(f"cusparse_csr_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
- "Falling back to default implementation.")
- return [gpu_sparse.cuda_csr_matvec(data, indices, indptr, vector,
- shape=shape,
- transpose=transpose,
- data_dtype=dtype,
- x_dtype=v_aval.dtype,
- index_dtype=indices_aval.dtype)]
+ data_aval, indices_aval, _, v_aval = ctx.avals_in
+ dtype = data_aval.dtype
+ if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
+ raise TypeError(f"cusparse_csr_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
+ "Falling back to default implementation.")
+ return [gpu_sparse.cuda_csr_matvec(data, indices, indptr, vector,
+ shape=shape,
+ transpose=transpose,
+ data_dtype=dtype,
+ x_dtype=v_aval.dtype,
+ index_dtype=indices_aval.dtype)]
def _csrmv_jvp_mat(csr_prim, data_dot, data, indices, indptr, v, *, shape, transpose):
- return csr_prim.bind(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
+ return csr_prim.bind(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
def _csrmv_jvp_vec(prim, v_dot, data, indices, indptr, v, *, shape, transpose):
- return prim.bind(data, indices, indptr, v_dot, shape=shape, transpose=transpose)
+ return prim.bind(data, indices, indptr, v_dot, shape=shape, transpose=transpose)
def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose):
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
- raise ValueError("Cannot transpose with respect to sparse indices.")
+ if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+ raise ValueError("Cannot transpose with respect to sparse indices.")
- if ad.is_undefined_primal(vector):
- ct_vector = _csrmv_cusparse_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose)
- return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+ if ad.is_undefined_primal(vector):
+ if type(ct) is ad.Zero:
+ return data, indices, indptr, ad.Zero(vector)
+ else:
+ ct_vector = _csrmv_cusparse_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose)
+ return data, indices, indptr, ct_vector
+ else:
+ if type(ct) is ad.Zero:
+ ct_data = ad.Zero(data)
else:
- if type(ct) is ad.Zero:
- ct_data = ad.Zero(data)
- else:
- if data.aval.shape[0] == 1: # scalar
- ct_data = _csrmv_cusparse_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
- ct_data = jnp.inner(ct, ct_data)
- else: # heterogeneous values
- row, col = csr_to_coo(indices, indptr)
- ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
- return ct_data, indices, indptr, vector
+ if data.aval.shape[0] == 1: # scalar
+ ct_data = _csrmv_cusparse_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
+ ct_data = jnp.inner(ct, ct_data)
+ else: # heterogeneous values
+ row, col = csr_to_coo(indices, indptr)
+ ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
+ return ct_data, indices, indptr, vector
_csrmv_cusparse_p = core.Primitive('cusparse_csr_matvec')
_csrmv_cusparse_p.def_abstract_eval(_csrmv_abstract)
_csrmv_cusparse_p.def_impl(partial(xla.apply_primitive, _csrmv_cusparse_p))
-xla.backend_specific_translations['cpu'][_csrmv_cusparse_p] = _csrmv_cpu_translation
+# xla.backend_specific_translations['cpu'][_csrmv_cusparse_p] = _csrmv_cpu_translation
ad.defjvp(_csrmv_cusparse_p,
partial(_csrmv_jvp_mat, _csrmv_cusparse_p),
None,
@@ -252,67 +313,67 @@ def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, trans
def _csr_matvec_scalar_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose):
- gpu_ops = import_brainpylib_gpu_ops()
- if gpu_ops is None:
- raise GPUOperatorNotFound(_csrmv_scalar_p.name)
- if transpose:
- raise NotImplementedError
-
- data_shape = c.get_shape(data)
- if data_shape.element_type() == np.float32:
- ftype = b'_float'
- elif data_shape.element_type() == np.float64:
- ftype = b'_double'
- else:
- raise ValueError
- indices_shape = c.get_shape(indices)
- if indices_shape.element_type() == np.int32:
- itype = b'_int'
- elif indices_shape.element_type() == np.int64:
- itype = b'_long'
- else:
- raise ValueError
- data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter'
- opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1])
- return xla_client.ops.CustomCallWithLayout(
- c,
- b'csrmv_' + data_name + b'_scalar' + ftype + itype,
- operands=(data, indices, indptr, vector),
- operand_shapes_with_layout=(c.get_shape(data),
- c.get_shape(indices),
- c.get_shape(indptr),
- c.get_shape(vector)),
- shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)),
- opaque=opaque,
- )
+ gpu_ops = import_brainpylib_gpu_ops()
+ if gpu_ops is None:
+ raise GPUOperatorNotFound(_csrmv_scalar_p.name)
+ if transpose:
+ raise NotImplementedError
+
+ data_shape = c.get_shape(data)
+ if data_shape.element_type() == np.float32:
+ ftype = b'_float'
+ elif data_shape.element_type() == np.float64:
+ ftype = b'_double'
+ else:
+ raise ValueError
+ indices_shape = c.get_shape(indices)
+ if indices_shape.element_type() == np.int32:
+ itype = b'_int'
+ elif indices_shape.element_type() == np.int64:
+ itype = b'_long'
+ else:
+ raise ValueError
+ data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter'
+ opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1])
+ return xla_client.ops.CustomCallWithLayout(
+ c,
+ b'csrmv_' + data_name + b'_scalar' + ftype + itype,
+ operands=(data, indices, indptr, vector),
+ operand_shapes_with_layout=(c.get_shape(data),
+ c.get_shape(indices),
+ c.get_shape(indptr),
+ c.get_shape(vector)),
+ shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)),
+ opaque=opaque,
+ )
def _csrmv_scalar_transpose(ct, data, indices, indptr, vector, *, shape, transpose):
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
- raise ValueError("Cannot transpose with respect to sparse indices.")
+ if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+ raise ValueError("Cannot transpose with respect to sparse indices.")
- if ad.is_undefined_primal(vector):
- ct_vector = _csrmv_scalar_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose)
- return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+ if ad.is_undefined_primal(vector):
+ ct_vector = _csrmv_scalar_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose)
+ return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+ else:
+ if type(ct) is ad.Zero:
+ ct_data = ad.Zero(data)
else:
- if type(ct) is ad.Zero:
- ct_data = ad.Zero(data)
- else:
- if data.aval.shape[0] == 1: # scalar
- ct_data = _csrmv_scalar_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
- ct_data = jnp.inner(ct, ct_data)
- else: # heterogeneous values
- row, col = csr_to_coo(indices, indptr)
- ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
- return ct_data, indices, indptr, vector
+ if data.aval.shape[0] == 1: # scalar
+ ct_data = _csrmv_scalar_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
+ ct_data = jnp.inner(ct, ct_data)
+ else: # heterogeneous values
+ row, col = csr_to_coo(indices, indptr)
+ ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
+ return ct_data, indices, indptr, vector
_csrmv_scalar_p = core.Primitive('csr_matvec_scalar')
_csrmv_scalar_p.def_abstract_eval(_csrmv_abstract)
_csrmv_scalar_p.def_impl(partial(xla.apply_primitive, _csrmv_scalar_p))
-xla.backend_specific_translations['cpu'][_csrmv_scalar_p] = _csrmv_cpu_translation
-xla.backend_specific_translations['gpu'][_csrmv_scalar_p] = _csr_matvec_scalar_gpu_translation
+# xla.backend_specific_translations['cpu'][_csrmv_scalar_p] = _csrmv_cpu_translation
+# xla.backend_specific_translations['gpu'][_csrmv_scalar_p] = _csr_matvec_scalar_gpu_translation
ad.defjvp(_csrmv_scalar_p,
partial(_csrmv_jvp_mat, _csrmv_scalar_p),
None,
@@ -323,67 +384,67 @@ def _csrmv_scalar_transpose(ct, data, indices, indptr, vector, *, shape, transpo
def _csr_matvec_vector_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose):
- gpu_ops = import_brainpylib_gpu_ops()
- if gpu_ops is None:
- raise GPUOperatorNotFound(_csrmv_vector_p.name)
- if transpose:
- raise NotImplementedError
-
- data_shape = c.get_shape(data)
- if data_shape.element_type() == np.float32:
- ftype = b'_float'
- elif data_shape.element_type() == np.float64:
- ftype = b'_double'
- else:
- raise ValueError
- indices_shape = c.get_shape(indices)
- if indices_shape.element_type() == np.int32:
- itype = b'_int'
- elif indices_shape.element_type() == np.int64:
- itype = b'_long'
- else:
- raise ValueError
- data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter'
- opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1])
- return xla_client.ops.CustomCallWithLayout(
- c,
- b'csrmv_' + data_name + b'_vector' + ftype + itype,
- operands=(data, indices, indptr, vector),
- operand_shapes_with_layout=(c.get_shape(data),
- c.get_shape(indices),
- c.get_shape(indptr),
- c.get_shape(vector)),
- shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)),
- opaque=opaque,
- )
+ gpu_ops = import_brainpylib_gpu_ops()
+ if gpu_ops is None:
+ raise GPUOperatorNotFound(_csrmv_vector_p.name)
+ if transpose:
+ raise NotImplementedError
+
+ data_shape = c.get_shape(data)
+ if data_shape.element_type() == np.float32:
+ ftype = b'_float'
+ elif data_shape.element_type() == np.float64:
+ ftype = b'_double'
+ else:
+ raise ValueError
+ indices_shape = c.get_shape(indices)
+ if indices_shape.element_type() == np.int32:
+ itype = b'_int'
+ elif indices_shape.element_type() == np.int64:
+ itype = b'_long'
+ else:
+ raise ValueError
+ data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter'
+ opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1])
+ return xla_client.ops.CustomCallWithLayout(
+ c,
+ b'csrmv_' + data_name + b'_vector' + ftype + itype,
+ operands=(data, indices, indptr, vector),
+ operand_shapes_with_layout=(c.get_shape(data),
+ c.get_shape(indices),
+ c.get_shape(indptr),
+ c.get_shape(vector)),
+ shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)),
+ opaque=opaque,
+ )
def _csrmv_vector_transpose(ct, data, indices, indptr, vector, *, shape, transpose):
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
- raise ValueError("Cannot transpose with respect to sparse indices.")
+ if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+ raise ValueError("Cannot transpose with respect to sparse indices.")
- if ad.is_undefined_primal(vector):
- ct_vector = _csrmv_vector_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose)
- return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+ if ad.is_undefined_primal(vector):
+ ct_vector = _csrmv_vector_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose)
+ return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+ else:
+ if type(ct) is ad.Zero:
+ ct_data = ad.Zero(data)
else:
- if type(ct) is ad.Zero:
- ct_data = ad.Zero(data)
- else:
- if data.aval.shape[0] == 1: # scalar
- ct_data = _csrmv_vector_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
- ct_data = jnp.inner(ct, ct_data)
- else: # heterogeneous values
- row, col = csr_to_coo(indices, indptr)
- ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
- return ct_data, indices, indptr, vector
+ if data.aval.shape[0] == 1: # scalar
+ ct_data = _csrmv_vector_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
+ ct_data = jnp.inner(ct, ct_data)
+ else: # heterogeneous values
+ row, col = csr_to_coo(indices, indptr)
+ ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
+ return ct_data, indices, indptr, vector
_csrmv_vector_p = core.Primitive('csr_matvec_vector')
_csrmv_vector_p.def_abstract_eval(_csrmv_abstract)
_csrmv_vector_p.def_impl(partial(xla.apply_primitive, _csrmv_vector_p))
-xla.backend_specific_translations['cpu'][_csrmv_vector_p] = _csrmv_cpu_translation
-xla.backend_specific_translations['gpu'][_csrmv_vector_p] = _csr_matvec_vector_gpu_translation
+# xla.backend_specific_translations['cpu'][_csrmv_vector_p] = _csrmv_cpu_translation
+# xla.backend_specific_translations['gpu'][_csrmv_vector_p] = _csr_matvec_vector_gpu_translation
ad.defjvp(_csrmv_vector_p,
partial(_csrmv_jvp_mat, _csrmv_vector_p),
None,
@@ -394,68 +455,68 @@ def _csrmv_vector_transpose(ct, data, indices, indptr, vector, *, shape, transpo
def _csr_matvec_adaptive_gpu_translation(c, data, indices, indptr, row_blocks, vector, *, shape, transpose):
- gpu_ops = import_brainpylib_gpu_ops()
- if gpu_ops is None:
- raise GPUOperatorNotFound(_csrmv_adaptive_p.name)
- if transpose:
- raise NotImplementedError
-
- data_shape = c.get_shape(data)
- if data_shape.element_type() == np.float32:
- ftype = b'_float'
- elif data_shape.element_type() == np.float64:
- ftype = b'_double'
- else:
- raise ValueError
- indices_shape = c.get_shape(indices)
- if indices_shape.element_type() == np.int32:
- itype = b'_int'
- elif indices_shape.element_type() == np.int64:
- itype = b'_long'
- else:
- raise ValueError
- data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter'
- opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1])
- return xla_client.ops.CustomCallWithLayout(
- c,
- b'csrmv_' + data_name + b'_vector' + ftype + itype,
- operands=(data, indices, indptr, row_blocks, vector),
- operand_shapes_with_layout=(c.get_shape(data),
- c.get_shape(indices),
- c.get_shape(indptr),
- c.get_shape(row_blocks),
- c.get_shape(vector)),
- shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)),
- opaque=opaque,
- )
+ gpu_ops = import_brainpylib_gpu_ops()
+ if gpu_ops is None:
+ raise GPUOperatorNotFound(_csrmv_adaptive_p.name)
+ if transpose:
+ raise NotImplementedError
+
+ data_shape = c.get_shape(data)
+ if data_shape.element_type() == np.float32:
+ ftype = b'_float'
+ elif data_shape.element_type() == np.float64:
+ ftype = b'_double'
+ else:
+    raise ValueError(f'Unsupported data dtype: {data_shape.element_type()}.')
+ indices_shape = c.get_shape(indices)
+ if indices_shape.element_type() == np.int32:
+ itype = b'_int'
+ elif indices_shape.element_type() == np.int64:
+ itype = b'_long'
+ else:
+    raise ValueError(f'Unsupported index dtype: {indices_shape.element_type()}.')
+ data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter'
+ opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1])
+ return xla_client.ops.CustomCallWithLayout(
+ c,
+ b'csrmv_' + data_name + b'_vector' + ftype + itype,
+ operands=(data, indices, indptr, row_blocks, vector),
+ operand_shapes_with_layout=(c.get_shape(data),
+ c.get_shape(indices),
+ c.get_shape(indptr),
+ c.get_shape(row_blocks),
+ c.get_shape(vector)),
+ shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)),
+ opaque=opaque,
+ )
def _csrmv_adaptive_transpose(ct, data, indices, indptr, vector, *, shape, transpose):
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
- raise ValueError("Cannot transpose with respect to sparse indices.")
+ if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+ raise ValueError("Cannot transpose with respect to sparse indices.")
- if ad.is_undefined_primal(vector):
- ct_vector = _csrmv_adaptive_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose)
- return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+ if ad.is_undefined_primal(vector):
+ ct_vector = _csrmv_adaptive_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose)
+ return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+ else:
+ if type(ct) is ad.Zero:
+ ct_data = ad.Zero(data)
else:
- if type(ct) is ad.Zero:
- ct_data = ad.Zero(data)
- else:
- if data.aval.shape[0] == 1: # scalar
- ct_data = _csrmv_adaptive_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
- ct_data = jnp.inner(ct, ct_data)
- else: # heterogeneous values
- row, col = csr_to_coo(indices, indptr)
- ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
- return ct_data, indices, indptr, vector
+ if data.aval.shape[0] == 1: # scalar
+ ct_data = _csrmv_adaptive_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
+ ct_data = jnp.inner(ct, ct_data)
+ else: # heterogeneous values
+ row, col = csr_to_coo(indices, indptr)
+ ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
+ return ct_data, indices, indptr, vector
_csrmv_adaptive_p = core.Primitive('csr_matvec_adaptive')
_csrmv_adaptive_p.def_abstract_eval(_csrmv_abstract)
_csrmv_adaptive_p.def_impl(partial(xla.apply_primitive, _csrmv_adaptive_p))
-xla.backend_specific_translations['cpu'][_csrmv_adaptive_p] = _csrmv_cpu_translation
-xla.backend_specific_translations['gpu'][_csrmv_adaptive_p] = _csr_matvec_adaptive_gpu_translation
+# xla.backend_specific_translations['cpu'][_csrmv_adaptive_p] = _csrmv_cpu_translation
+# xla.backend_specific_translations['gpu'][_csrmv_adaptive_p] = _csr_matvec_adaptive_gpu_translation
ad.defjvp(_csrmv_adaptive_p,
partial(_csrmv_jvp_mat, _csrmv_adaptive_p),
None,
@@ -463,3 +524,289 @@ def _csrmv_adaptive_transpose(ct, data, indices, indptr, vector, *, shape, trans
partial(_csrmv_jvp_vec, _csrmv_adaptive_p), )
ad.primitive_transposes[_csrmv_adaptive_p] = _csrmv_adaptive_transpose
register_general_batching(_csrmv_adaptive_p)
+
+
+### TAICHI ###
+
+def csrmv_taichi(
+ data: Union[float, jnp.ndarray, Array],
+ indices: Union[jnp.ndarray, Array],
+ indptr: Union[jnp.ndarray, Array],
+ vector: Union[jnp.ndarray, Array],
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+) -> jax.Array:
+  """Product of a CSR sparse matrix and a dense vector, implemented with Taichi kernels.
+
+ This function supports JAX transformations, including `jit()`, `grad()`,
+ `vmap()` and `pmap()`.
+
+ Parameters
+ ----------
+ data: ndarray, float
+ An array of shape ``(nse,)``.
+ indices: ndarray
+ An array of shape ``(nse,)``.
+ indptr: ndarray
+ An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+ vector: ndarray
+ An array of shape ``(shape[0] if transpose else shape[1],)``
+ and dtype ``data.dtype``.
+ shape: tuple of int
+ A length-2 tuple representing the matrix shape.
+ transpose: bool
+ A boolean specifying whether to transpose the sparse matrix
+ before computing.
+
+ Returns
+ -------
+  y : ndarray
+ The array of shape ``(shape[1] if transpose else shape[0],)`` representing
+ the matrix vector product.
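+
+  Examples
+  --------
+  A minimal illustration (the values below are chosen only for demonstration)::
+
+      import jax.numpy as jnp
+      import brainpy.math as bm
+
+      indptr = jnp.array([0, 2, 3])    # 2 rows
+      indices = jnp.array([0, 2, 1])   # column indices of the 3 nonzeros
+      data = jnp.array([1., 2., 3.])
+      vec = jnp.ones(3)
+      y = bm.sparse.csrmv_taichi(data, indices, indptr, vec, shape=(2, 3))
+      # dense equivalent: [[1., 0., 2.], [0., 3., 0.]] @ vec  ->  y == [3., 3.]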
+ """
+
+ data = jnp.atleast_1d(as_jax(data))
+ indices = as_jax(indices)
+ indptr = as_jax(indptr)
+ vector = as_jax(vector)
+
+ if vector.dtype == jnp.bool_:
+ vector = as_jax(vector, dtype=data.dtype)
+
+  if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]:
+    raise TypeError(f'Only float16, float32 and float64 are supported, but got {data.dtype}.')
+  if data.dtype != vector.dtype:
+    raise TypeError(f'data and vector must have the same dtype, but got {data.dtype} != {vector.dtype}.')
+ assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1
+ if not jnp.issubdtype(indices.dtype, jnp.integer):
+ raise ValueError('indices should be a 1D vector with integer type.')
+ if not jnp.issubdtype(indptr.dtype, jnp.integer):
+ raise ValueError('indptr should be a 1D vector with integer type.')
+
+ # if the shape of indices is (0,), then we return a zero vector
+ if indices.shape[0] == 0:
+ return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)
+
+ return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0]
+
+
+# -------------
+# CPU operators
+# -------------
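+# Note: the transposed kernels below scatter into ``out`` across different
+# rows, so their outer loops are serialized with ``ti.loop_config(serialize=True)``
+# to avoid write conflicts; the non-transposed kernels accumulate one output
+# row per iteration and can therefore run in parallel.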
+
+
+@ti.kernel
+def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
+ col_indices: ti.types.ndarray(ndim=1),
+ row_ptr: ti.types.ndarray(ndim=1),
+ vector: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ ti.loop_config(serialize=True)
+ for row_i in range(row_ptr.shape[0] - 1):
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ out[col_indices[j]] += value * vector[row_i]
+
+
+@ti.kernel
+def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
+ col_indices: ti.types.ndarray(ndim=1),
+ row_ptr: ti.types.ndarray(ndim=1),
+ vector: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ ti.loop_config(serialize=True)
+ for row_i in range(row_ptr.shape[0] - 1):
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ out[col_indices[j]] += vector[row_i] * values[j]
+
+
+@ti.kernel
+def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1),
+ col_indices: ti.types.ndarray(ndim=1),
+ row_ptr: ti.types.ndarray(ndim=1),
+ vector: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ # ti.loop_config(serialize=True)
+ for row_i in range(row_ptr.shape[0] - 1):
+ r = 0.
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ r += vector[col_indices[j]]
+ out[row_i] = r * value
+
+
+@ti.kernel
+def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1),
+ col_indices: ti.types.ndarray(ndim=1),
+ row_ptr: ti.types.ndarray(ndim=1),
+ vector: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ # ti.loop_config(serialize=True)
+ for row_i in range(row_ptr.shape[0] - 1):
+ r = 0.
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ r += values[j] * vector[col_indices[j]]
+ out[row_i] = r
+
+
+# -------------
+# GPU operators
+# -------------
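+# Note: the kernels below emulate a warp-per-row scheme.  For a flat index
+# ``i``, ``i >> 5`` selects the row and ``i & 31`` the lane (32 lanes per row);
+# each lane then strides over the row's nonzeros in steps of 32.  Concurrent
+# ``+=`` updates into ``out`` rely on Taichi demoting them to atomic adds.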
+
+
+@ti.kernel
+def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
+ col_indices: ti.types.ndarray(ndim=1),
+ row_ptr: ti.types.ndarray(ndim=1),
+ vector: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ for i in range((row_ptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ j = row_ptr[row_i] + index
+ end_index = row_ptr[row_i + 1]
+ while j < end_index:
+ out[col_indices[j]] += value * vector[row_i]
+ j += 32
+
+
+@ti.kernel
+def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
+ col_indices: ti.types.ndarray(ndim=1),
+ row_ptr: ti.types.ndarray(ndim=1),
+ vector: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ value = values[0]
+ for i in range((row_ptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ r = 0.
+ j = row_ptr[row_i] + index
+ end_index = row_ptr[row_i + 1]
+ while j < end_index:
+ r += vector[col_indices[j]]
+ j += 32
+ out[row_i] += value * r
+
+
+@ti.kernel
+def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
+ col_indices: ti.types.ndarray(ndim=1),
+ row_ptr: ti.types.ndarray(ndim=1),
+ vector: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ for i in range((row_ptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ j = row_ptr[row_i] + index
+ end_index = row_ptr[row_i + 1]
+ while j < end_index:
+ out[col_indices[j]] += values[j] * vector[row_i]
+ j += 32
+
+
+@ti.kernel
+def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
+ col_indices: ti.types.ndarray(ndim=1),
+ row_ptr: ti.types.ndarray(ndim=1),
+ vector: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ for i in range((row_ptr.shape[0] - 1) * 32):
+ row_i = i >> 5
+ index = i & 31
+ r = 0.
+ j = row_ptr[row_i] + index
+ end_index = row_ptr[row_i + 1]
+ while j < end_index:
+ r += values[j] * vector[col_indices[j]]
+ j += 32
+ out[row_i] += r # TODO: warp-level primitive
+
+
+def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
+ return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose)
+
+
+def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
+ return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose)
+
+
+def _sparse_csr_matvec_transpose(
+ ct, data, indices, indptr, vector, *, outs, transpose, shape,
+):
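+  # Cotangent rules (descriptive note): the cotangent w.r.t. ``vector`` is
+  # obtained by re-binding the primitive with ``transpose`` flipped, i.e.
+  # multiplying by the transposed matrix; for heterogeneous ``data`` the
+  # cotangent of entry ``j`` is ``vector[col[j]] * ct[row[j]]``, with the row
+  # and column roles swapped when ``transpose`` is True.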
+ if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+ raise ValueError("Cannot transpose with respect to sparse indices.")
+ if ad.is_undefined_primal(vector):
+ ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
+ return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector)
+
+ else:
+ if type(ct[0]) is ad.Zero:
+ ct_data = ad.Zero(data)
+ else:
+ if data.aval.shape[0] == 1: # scalar
+ ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0]
+ ct_data = jnp.inner(ct[0], ct_data)
+ else:
+ row, col = csr_to_coo(indices, indptr)
+ ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row]
+
+ return ct_data, indices, indptr, vector
+
+
+def raw_csrmv_taichi(
+ data: Union[float, jnp.ndarray, Array],
+ indices: Union[jnp.ndarray, Array],
+ indptr: Union[jnp.ndarray, Array],
+ vector: Union[jnp.ndarray, Array],
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+):
+ out_shape = shape[1] if transpose else shape[0]
+ if transpose:
+ if data.shape[0] == 1:
+ prim = _csr_matvec_transpose_homo_p
+ else:
+ prim = _csr_matvec_transpose_heter_p
+ else:
+ if data.shape[0] == 1:
+ prim = _csr_matvec_homo_p
+ else:
+ prim = _csr_matvec_heter_p
+
+ return prim(data,
+ indices,
+ indptr,
+ vector,
+ outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
+ transpose=transpose,
+ shape=shape)
+
+
+def _define_op(cpu_kernel, gpu_kernel):
+ prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
+ prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector)
+ prim.def_transpose_rule(_sparse_csr_matvec_transpose)
+ return prim
+
+
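+# The four primitives below cover the (transpose, homo/heter) combinations;
+# ``raw_csrmv_taichi`` picks one based on the ``transpose`` flag and on whether
+# ``data`` holds a single homogeneous value or one value per nonzero.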
+# transpose homo
+_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu,
+ gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu)
+
+# no transpose homo
+_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu,
+ gpu_kernel=_sparse_csr_matvec_homo_gpu)
+
+# transpose heter
+_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu,
+ gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu)
+
+# no transpose heter
+_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu,
+ gpu_kernel=_sparse_csr_matvec_heter_gpu)
diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py
new file mode 100644
index 000000000..1db246212
--- /dev/null
+++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py
@@ -0,0 +1,250 @@
+# from jax_taichi import jax_taichi_call
+
+import time
+from functools import partial
+import os
+
+import brainpy as bp
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import taichi as ti
+
+bm.set_platform('cpu')
+
+s = [1000, 5000, 10000, 15000, 20000, 25000, 30000]
+p = [0.1, 0.2, 0.3, 0.4, 0.5]
+
+shape = [
+ 1000,
+ 2500,
+ 5000,
+ 10000,
+ 25000,
+ 37500,
+ 50000
+]
+
+values_type = [
+ 'homo',
+ 'heter'
+ ]
+events_type = ['float']
+transpose = [
+ True,
+ False
+ ]
+method = 'cusparse'
+
+ITERATION = 100
+if bm.get_platform() == 'cpu':
+ ITERATION = 10
+
+print(bm.get_platform())
+
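+# Each jitted wrapper below folds ITERATION matrix-vector products into one
+# compiled call, so a single timing measurement covers ITERATION evaluations;
+# the five untimed calls in test_sparse_csrmv serve as warm-up / compilation.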
+@partial(jax.jit, static_argnums=(4, 5))
+def csrmv_taichi(weight, indices, indptr, vector, shape, transpose):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0]
+ return r
+
+@partial(jax.jit, static_argnums=(4, 5))
+def csrmv(weight, indices, indptr, vector, shape, transpose):
+ r = 0
+ for i in range(ITERATION):
+ r += bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)
+ return r
+
+def test_sparse_csrmv(shape, values_type, events_type, transpose):
+ rng = bm.random.RandomState(seed=1234)
+ indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
+ vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ weight = 1.
+
+
+ if events_type == 'float':
+ vector = vector.astype(bm.float32)
+ if values_type == 'heter':
+ heter_data = bm.ones(indices.shape) * weight
+ weight = heter_data
+
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+# init dataframe
+df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose',
+ 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
+ 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
+ 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+ 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
+
+### RECTANGULAR MATRIX
+if (bm.get_platform() == 'cpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _values_type in values_type:
+ for _events_type in events_type:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
+ # append to dataframe
+ df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/csrmv_cpu.csv', index=False)
+
+if (bm.get_platform() == 'gpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _values_type in values_type:
+ for _events_type in events_type:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
+ # append to dataframe
+ df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/csrmv_gpu.csv', index=False)
diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py
new file mode 100644
index 000000000..d902c9395
--- /dev/null
+++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py
@@ -0,0 +1,273 @@
+# from jax_taichi import jax_taichi_call
+
+import time
+from functools import partial
+import os
+
+import brainpy as bp
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import taichi as ti
+
+bm.set_platform('cpu')
+
+s = [1000,
+ 5000,
+ 10000,
+ 15000,
+ 20000,
+ 25000,
+ 30000]
+p = [0.1, 0.2, 0.3, 0.4, 0.5]
+
+shape = [
+ 1000,
+ 2500,
+ 5000,
+ 10000,
+ 25000,
+ 37500,
+ 50000
+]
+
+values_type = [
+ 'homo',
+ 'heter'
+ ]
+events_type = ['float']
+transpose = [
+ True,
+ False
+ ]
+method = 'cusparse'
+
+ITERATION = 100
+if bm.get_platform() == 'cpu':
+ ITERATION = 10
+
+print(bm.get_platform())
+
+def sum_op(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)
+ return r.sum()
+
+ return func
+
+
+def sum_op2(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)[0]
+ return r.sum()
+
+ return func
+
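+# sum_op / sum_op2 reduce an operator's output (or its first element, for
+# operators that return a sequence) to a scalar so that jax.grad can be applied.
+# Each jitted wrapper below folds ITERATION gradient evaluations into one
+# compiled call; the five untimed calls in test_sparse_csrmv act as warm-up.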
+@partial(jax.jit, static_argnums=(4, 5))
+def csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)(
+ weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+ return r
+
+@partial(jax.jit, static_argnums=(4, 5))
+def csrmv_grad(weight, indices, indptr, vector, shape, transpose):
+ r = 0
+ for i in range(ITERATION):
+ r += jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
+ weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+ return r
+
+def test_sparse_csrmv(shape, values_type, events_type, transpose):
+ rng = bm.random.RandomState(seed=1234)
+ indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
+ vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
+ weight = 1.
+
+
+ if events_type == 'float':
+ vector = vector.astype(bm.float32)
+ if values_type == 'heter':
+ heter_data = bm.ones(indices.shape) * weight
+ weight = heter_data
+
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+
+ time0 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time1 = time.time()
+
+ time2 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time3 = time.time()
+
+ time4 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time5 = time.time()
+
+ time6 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time7 = time.time()
+
+ time8 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time9 = time.time()
+
+ time10 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time11 = time.time()
+
+ time12 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time13 = time.time()
+
+ time14 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time15 = time.time()
+
+ time16 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time17 = time.time()
+
+ time18 = time.time()
+ result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time19 = time.time()
+
+
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+
+ time20 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time21 = time.time()
+
+ time22 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time23 = time.time()
+
+ time24 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time25 = time.time()
+
+ time26 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time27 = time.time()
+
+ time28 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time29 = time.time()
+
+ time30 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time31 = time.time()
+
+ time32 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time33 = time.time()
+
+ time34 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time35 = time.time()
+
+ time36 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time37 = time.time()
+
+ time38 = time.time()
+ result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
+ time39 = time.time()
+
+ taichi_aot_time1 = (time1 - time0) * 1000
+ taichi_aot_time2 = (time3 - time2) * 1000
+ taichi_aot_time3 = (time5 - time4) * 1000
+ taichi_aot_time4 = (time7 - time6) * 1000
+ taichi_aot_time5 = (time9 - time8) * 1000
+ taichi_aot_time6 = (time11 - time10) * 1000
+ taichi_aot_time7 = (time13 - time12) * 1000
+ taichi_aot_time8 = (time15 - time14) * 1000
+ taichi_aot_time9 = (time17 - time16) * 1000
+ taichi_aot_time10 = (time19 - time18) * 1000
+ brainpy_time1 = (time21 - time20) * 1000
+ brainpy_time2 = (time23 - time22) * 1000
+ brainpy_time3 = (time25 - time24) * 1000
+ brainpy_time4 = (time27 - time26) * 1000
+ brainpy_time5 = (time29 - time28) * 1000
+ brainpy_time6 = (time31 - time30) * 1000
+ brainpy_time7 = (time33 - time32) * 1000
+ brainpy_time8 = (time35 - time34) * 1000
+ brainpy_time9 = (time37 - time36) * 1000
+ brainpy_time10 = (time39 - time38) * 1000
+ print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
+ print('taichi_aot_1: ', taichi_aot_time1, 'ms')
+ print('taichi_aot_3: ', taichi_aot_time3, 'ms')
+ print('taichi_aot_5: ', taichi_aot_time5, 'ms')
+ print('taichi_aot_7: ', taichi_aot_time7, 'ms')
+ print('taichi_aot_9: ', taichi_aot_time9, 'ms')
+ print('brainpylib_1: ', brainpy_time1, 'ms')
+ print('brainpylib_3: ', brainpy_time3, 'ms')
+ print('brainpylib_5: ', brainpy_time5, 'ms')
+ print('brainpylib_7: ', brainpy_time7, 'ms')
+ print('brainpylib_9: ', brainpy_time9, 'ms')
+
+
+ return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
+ taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
+ brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
+ brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+# init dataframe
+df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose',
+ 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
+ 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
+ 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+ 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
+
+
+### RECTANGULAR MATRIX
+if (bm.get_platform() == 'cpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _values_type in values_type:
+ for _events_type in events_type:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
+ # append to dataframe
+ df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/csrmv_grad_cpu.csv', index=False)
+
+if (bm.get_platform() == 'gpu'):
+ for shape1 in shape:
+ for shape2 in shape:
+ for _values_type in values_type:
+ for _events_type in events_type:
+ for _transpose in transpose:
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
+ # append to dataframe
+ df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose,
+ taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
+ taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
+ brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+ brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+ df.to_csv(f'{PATH}/csrmv_grad_gpu.csv', index=False)
diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py
index 16bf43a48..2c75f0901 100644
--- a/brainpy/_src/math/sparse/tests/test_csrmv.py
+++ b/brainpy/_src/math/sparse/tests/test_csrmv.py
@@ -3,24 +3,60 @@
from functools import partial
import jax
-import pytest
from absl.testing import parameterized
-import platform
+
import brainpy as bp
import brainpy.math as bm
-is_manual_test = False
-if platform.system() == 'Windows' and not is_manual_test:
- pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
+# bm.set_platform('gpu')
+
+seed = 1234
+
+
+def sum_op(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)
+ return r.sum()
+
+ return func
+
+
+
+def compare_with_nan_tolerance(a, b, tol=1e-8):
+ """
+ Compare two arrays with tolerance for NaN values.
+
+ Parameters:
+ a (np.array): First array to compare.
+ b (np.array): Second array to compare.
+ tol (float): Tolerance for comparing non-NaN elements.
+
+ Returns:
+ bool: True if arrays are similar within the tolerance, False otherwise.
+ """
+ if a.shape != b.shape:
+ return False
+
+ # Create masks for NaNs in both arrays
+ nan_mask_a = bm.isnan(a)
+ nan_mask_b = bm.isnan(b)
+
+ # Check if NaN positions are the same in both arrays
+ if not bm.array_equal(nan_mask_a, nan_mask_b):
+ return False
+
+ # Compare non-NaN elements
+ a_non_nan = a[~nan_mask_a]
+ b_non_nan = b[~nan_mask_b]
-cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse')
-scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar')
-vector_csr_matvec = partial(bm.sparse.csrmv, method='vector')
+ return bm.allclose(a_non_nan, b_non_nan, atol=tol)
-class Test_cusparse_csrmv(parameterized.TestCase):
+taichi_csr_matvec = bm.sparse.csrmv
+
+class Test_csrmv_taichi(parameterized.TestCase):
def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_cusparse_csrmv, self).__init__(*args, **kwargs)
+ super(Test_csrmv_taichi, self).__init__(*args, **kwargs)
print()
bm.set_platform(platform)
@@ -31,35 +67,36 @@ def __init__(self, *args, platform='cpu', **kwargs):
homo_data=[-1., 0., 1.]
)
def test_homo(self, transpose, shape, homo_data):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
+ print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
+ conn = bp.conn.FixedProb(0.3)
+ # matrix
indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
-
- heter_data = bm.ones(indices.shape).value * homo_data
-
+ # vector
+ rng = bm.random.RandomState(seed=seed)
vector = rng.random(shape[0] if transpose else shape[1])
vector = bm.as_jax(vector)
- r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
- r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r2))
+
+ heter_data = bm.ones(indices.shape).value * homo_data
dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- r3 = (vector @ dense) if transpose else (dense @ vector)
- self.assertTrue(bm.allclose(r1, r3))
+ r1 = (vector @ dense) if transpose else (dense @ vector)
+ r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r2))
bm.clear_buffer_memory()
@parameterized.product(
transpose=[True, False],
- shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)],
+ shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)],
v=[-1., 0., 1.]
)
def test_homo_vmap(self, transpose, shape, v):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
+ print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}')
+ rng = bm.random.RandomState(seed=seed)
+ conn = bp.conn.FixedProb(0.3)
indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
@@ -71,17 +108,13 @@ def test_homo_vmap(self, transpose, shape, v):
homo_data = bm.ones(10).value * v
dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data)
- f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector,
+ f1 = lambda a: (a.T @ vector) if transpose else (a @ vector)
+ f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector,
shape=shape, transpose=transpose)
- f2 = lambda a: (a.T @ vector) if transpose else (a @ vector)
-
- r1 = jax.vmap(f1)(homo_data)
- r2 = jax.vmap(f1)(heter_data)
+ r1 = jax.vmap(f1)(dense_data)
+ r2 = jax.vmap(f2)(homo_data)
self.assertTrue(bm.allclose(r1, r2))
- r3 = jax.vmap(f2)(dense_data)
- self.assertTrue(bm.allclose(r1, r3))
-
bm.clear_buffer_memory()
@parameterized.product(
@@ -90,8 +123,9 @@ def test_homo_vmap(self, transpose, shape, v):
homo_data=[-1., 0., 1.]
)
def test_homo_grad(self, transpose, shape, homo_data):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
+ print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
+ rng = bm.random.RandomState(seed=seed)
+ conn = bp.conn.FixedProb(0.3)
indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
@@ -103,37 +137,35 @@ def test_homo_grad(self, transpose, shape, homo_data):
vector = rng.random(shape[0] if transpose else shape[1])
vector = bm.as_jax(vector)
- csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector,
- shape=shape, transpose=transpose).sum(),
- argnums=0)
+ # print('grad data start')
+ # grad 'data'
dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum()
if transpose else
((dense * a) @ vector).sum()),
argnums=0)
+ r1 = dense_f1(homo_data)
+ r2 = jax.grad(sum_op(taichi_csr_matvec))(
+ homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
- r1 = csr_f1(homo_data)
- r2 = dense_f1(homo_data)
self.assertTrue(bm.allclose(r1, r2))
- csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v,
- shape=shape, transpose=transpose).sum())
+ # print('grad vector start')
+ # grad 'vector'
dense_data = dense * homo_data
dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()))
+ r3 = dense_f2(vector)
+ r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
+ homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
- r3 = csr_f2(vector)
- r4 = dense_f2(vector)
self.assertTrue(bm.allclose(r3, r4))
- csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v,
- shape=shape, transpose=transpose).sum(),
- argnums=(0, 1))
dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum()
if transpose else
((dense * a) @ v).sum()),
argnums=(0, 1))
-
- r5 = csr_f3(homo_data, vector)
- r6 = dense_f3(homo_data, vector)
+ r5 = dense_f3(homo_data, vector)
+ r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))(
+ homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r5[0], r6[0]))
self.assertTrue(bm.allclose(r5[1], r6[1]))
@@ -141,26 +173,28 @@ def test_homo_grad(self, transpose, shape, homo_data):
@parameterized.product(
transpose=[True, False],
- shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)],
+ shape=[(200, 200), (200, 100), (2, 2000)],
)
def test_heter(self, transpose, shape):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
+    print(f'test_heter: transpose = {transpose} shape = {shape}')
+ rng = bm.random.RandomState(seed=seed)
+ conn = bp.conn.FixedProb(0.3)
indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
- heter_data = rng.random(indices.shape)
+ heter_data = bm.as_jax(rng.random(indices.shape))
heter_data = bm.as_jax(heter_data)
vector = rng.random(shape[0] if transpose else shape[1])
vector = bm.as_jax(vector)
- r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector,
- shape=shape, transpose=transpose)
+
dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- r2 = (vector @ dense) if transpose else (dense @ vector)
- self.assertTrue(bm.allclose(r1, r2))
+ r1 = (vector @ dense) if transpose else (dense @ vector)
+ r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
+
+ self.assertTrue(compare_with_nan_tolerance(r1, r2))
bm.clear_buffer_memory()
@@ -169,8 +203,8 @@ def test_heter(self, transpose, shape):
shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)]
)
def test_heter_vmap(self, transpose, shape):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
+ rng = bm.random.RandomState(seed=seed)
+ conn = bp.conn.FixedProb(0.3)
indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
@@ -183,23 +217,20 @@ def test_heter_vmap(self, transpose, shape):
dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr,
shape=shape))(heter_data)
- f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector,
+ f1 = lambda a: (a.T @ vector) if transpose else (a @ vector)
+ f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector,
shape=shape, transpose=transpose)
- f2 = lambda a: (a.T @ vector) if transpose else (a @ vector)
-
- r1 = jax.vmap(f1)(heter_data)
- r2 = jax.vmap(f2)(dense_data)
- self.assertTrue(bm.allclose(r1, r2))
-
- bm.clear_buffer_memory()
+ r1 = jax.vmap(f1)(dense_data)
+ r2 = jax.vmap(f2)(heter_data)
+ self.assertTrue(compare_with_nan_tolerance(r1, r2))
@parameterized.product(
transpose=[True, False],
shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)]
)
def test_heter_grad(self, transpose, shape):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
+ rng = bm.random.RandomState(seed=seed)
+ conn = bp.conn.FixedProb(0.3)
indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
@@ -210,141 +241,29 @@ def test_heter_grad(self, transpose, shape):
vector = rng.random(shape[0] if transpose else shape[1])
vector = bm.as_jax(vector)
- csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector,
+ # grad 'data'
+ dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()),
+ argnums=0)
+ csr_f1 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector,
shape=shape,
transpose=transpose).sum(),
argnums=0)
- dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()),
- argnums=0)
-
r1 = csr_f1(heter_data)
r2 = dense_f1(dense_data)
rows, cols = bm.sparse.csr_to_coo(indices, indptr)
r2 = r2[rows, cols]
+ print(r1.shape, r2.shape)
self.assertTrue(bm.allclose(r1, r2))
- csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v,
- shape=shape,
- transpose=transpose).sum(),
- argnums=0)
+ # grad 'vector'
dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()),
argnums=0)
- r3 = csr_f2(vector)
- r4 = dense_f2(vector)
+ csr_f2 = jax.grad(lambda v: taichi_csr_matvec(heter_data, indices, indptr, v,
+ shape=shape,
+ transpose=transpose).sum(),
+ argnums=0)
+ r3 = dense_f2(vector)
+ r4 = csr_f2(vector)
self.assertTrue(bm.allclose(r3, r4))
bm.clear_buffer_memory()
-
-
-class Test_csrmv(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_csrmv, self).__init__(*args, **kwargs)
-
- print()
- bm.set_platform(platform)
-
- @parameterized.product(
- homo_data=[-1., 0., 0.1, 1.],
- shape=[(100, 200), (10, 1000), (2, 2000)],
- )
- def test_homo(self, shape, homo_data):
- conn = bp.conn.FixedProb(0.1)
-
- # matrix
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- # vector
- rng = bm.random.RandomState(123)
- vector = rng.random(shape[1])
- vector = bm.as_jax(vector)
-
- # csrmv
- r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape)
- r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape)
- r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape)
- self.assertTrue(bm.allclose(r1, r2))
- self.assertTrue(bm.allclose(r1, r3))
-
- heter_data = bm.ones(indices.shape).to_jax() * homo_data
- r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
- r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
- r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
- self.assertTrue(bm.allclose(r1, r4))
- self.assertTrue(bm.allclose(r1, r5))
- self.assertTrue(bm.allclose(r1, r6))
-
- dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- rdense = dense @ vector
- self.assertTrue(bm.allclose(r1, rdense))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)]
- )
- def test_heter(self, shape):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
-
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- heter_data = bm.as_jax(rng.random(indices.shape))
- vector = bm.as_jax(rng.random(shape[1]))
-
- r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
- r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
- r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
-
- dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- r4 = dense @ vector
- self.assertTrue(bm.allclose(r1, r2))
- self.assertTrue(bm.allclose(r1, r3))
- self.assertTrue(bm.allclose(r1, r4))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)]
- )
- def test_heter_grad(self, shape):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
-
- indices, indptr = conn(*shape).require('pre2post')
- heter_data = rng.random(indices.shape)
- dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- vector = rng.random(shape[1])
-
- csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum())
- csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum())
- csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum())
- dense_f1 = jax.grad(lambda a: (a @ vector).sum())
-
- r1 = csr_f1(heter_data)
- r2 = csr_f2(heter_data)
- r3 = csr_f3(heter_data)
-
- d1 = dense_f1(dense_data)
- rows, cols = bm.sparse.csr_to_coo(indices, indptr)
- d1 = d1[rows, cols]
- self.assertTrue(bm.allclose(r1, r2))
- self.assertTrue(bm.allclose(r1, r3))
- self.assertTrue(bm.allclose(r1, d1))
-
- # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum())
- # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum())
- # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum())
- # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum())
- # r4 = csr_f4(vector)
- # r5 = csr_f5(vector)
- # r6 = csr_f6(vector)
- # d2 = dense_f2(vector)
- # self.assertTrue(bm.allclose(r4, r5))
- # self.assertTrue(bm.allclose(r4, r6))
- # self.assertTrue(bm.allclose(r4, d2))
-
- bm.clear_buffer_memory()
-
-
diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py b/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py
deleted file mode 100644
index ccf090ec4..000000000
--- a/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import jax
-import pytest
-
-import test_csrmv
-
-if jax.default_backend() != 'gpu':
- pytest.skip("No gpu available.", allow_module_level=True)
-
-
-class Test_cusparse_csrmv_GPU(test_csrmv.Test_cusparse_csrmv):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs, platform='gpu')
-
-
-class Test__csrmv_GPU(test_csrmv.Test_csrmv):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs, platform='gpu')
-
-
diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_old.py b/brainpy/_src/math/sparse/tests/test_csrmv_old.py
new file mode 100644
index 000000000..b73217496
--- /dev/null
+++ b/brainpy/_src/math/sparse/tests/test_csrmv_old.py
@@ -0,0 +1,352 @@
+# -*- coding: utf-8 -*-
+
+from functools import partial
+
+import jax
+import pytest
+from absl.testing import parameterized
+import platform
+import brainpy as bp
+import brainpy.math as bm
+
+pytest.skip('Old implementation.', allow_module_level=True)
+
+is_manual_test = False
+# if platform.system() == 'Windows' and not is_manual_test:
+# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
+
+cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse')
+scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar')
+vector_csr_matvec = partial(bm.sparse.csrmv, method='vector')
+
+
+class Test_cusparse_csrmv(parameterized.TestCase):
+ def __init__(self, *args, platform='cpu', **kwargs):
+ super(Test_cusparse_csrmv, self).__init__(*args, **kwargs)
+
+ print()
+ bm.set_platform(platform)
+
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)],
+ homo_data=[-1., 0., 1.]
+ )
+ def test_homo(self, transpose, shape, homo_data):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+
+ heter_data = bm.ones(indices.shape).value * homo_data
+
+ vector = rng.random(shape[0] if transpose else shape[1])
+ vector = bm.as_jax(vector)
+ r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
+ r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
+ self.assertTrue(bm.allclose(r1, r2))
+
+ dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+ r3 = (vector @ dense) if transpose else (dense @ vector)
+ self.assertTrue(bm.allclose(r1, r3))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)],
+ v=[-1., 0., 1.]
+ )
+ def test_homo_vmap(self, transpose, shape, v):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ vector = rng.random(shape[0] if transpose else shape[1])
+ vector = bm.as_jax(vector)
+
+ heter_data = bm.ones((10, indices.shape[0])).value * v
+ homo_data = bm.ones(10).value * v
+ dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data)
+
+ f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector,
+ shape=shape, transpose=transpose)
+ f2 = lambda a: (a.T @ vector) if transpose else (a @ vector)
+
+ r1 = jax.vmap(f1)(homo_data)
+ r2 = jax.vmap(f1)(heter_data)
+ self.assertTrue(bm.allclose(r1, r2))
+
+ r3 = jax.vmap(f2)(dense_data)
+ self.assertTrue(bm.allclose(r1, r3))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)],
+ homo_data=[-1., 0., 1.]
+ )
+ def test_homo_grad(self, transpose, shape, homo_data):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value,
+ indices,
+ indptr,
+ shape=shape)
+ vector = rng.random(shape[0] if transpose else shape[1])
+ vector = bm.as_jax(vector)
+
+ csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector,
+ shape=shape, transpose=transpose).sum(),
+ argnums=0)
+ dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum()
+ if transpose else
+ ((dense * a) @ vector).sum()),
+ argnums=0)
+
+ r1 = csr_f1(homo_data)
+ r2 = dense_f1(homo_data)
+ self.assertTrue(bm.allclose(r1, r2))
+
+ csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v,
+ shape=shape, transpose=transpose).sum())
+ dense_data = dense * homo_data
+ dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()))
+
+ r3 = csr_f2(vector)
+ r4 = dense_f2(vector)
+ self.assertTrue(bm.allclose(r3, r4))
+
+ csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v,
+ shape=shape, transpose=transpose).sum(),
+ argnums=(0, 1))
+ dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum()
+ if transpose else
+ ((dense * a) @ v).sum()),
+ argnums=(0, 1))
+
+ r5 = csr_f3(homo_data, vector)
+ r6 = dense_f3(homo_data, vector)
+ self.assertTrue(bm.allclose(r5[0], r6[0]))
+ self.assertTrue(bm.allclose(r5[1], r6[1]))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)],
+ )
+ def test_heter(self, transpose, shape):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+
+ heter_data = rng.random(indices.shape)
+ heter_data = bm.as_jax(heter_data)
+
+ vector = rng.random(shape[0] if transpose else shape[1])
+ vector = bm.as_jax(vector)
+ r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector,
+ shape=shape, transpose=transpose)
+ dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+ r2 = (vector @ dense) if transpose else (dense @ vector)
+ self.assertTrue(bm.allclose(r1, r2))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)]
+ )
+ def test_heter_vmap(self, transpose, shape):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ vector = rng.random(shape[0] if transpose else shape[1])
+ vector = bm.as_jax(vector)
+
+ heter_data = rng.random((10, indices.shape[0]))
+ heter_data = bm.as_jax(heter_data)
+ dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr,
+ shape=shape))(heter_data)
+
+ f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector,
+ shape=shape, transpose=transpose)
+ f2 = lambda a: (a.T @ vector) if transpose else (a @ vector)
+
+ r1 = jax.vmap(f1)(heter_data)
+ r2 = jax.vmap(f2)(dense_data)
+ self.assertTrue(bm.allclose(r1, r2))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ transpose=[True, False],
+ shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)]
+ )
+ def test_heter_grad(self, transpose, shape):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ heter_data = rng.random(indices.shape)
+ heter_data = bm.as_jax(heter_data)
+ dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+ vector = rng.random(shape[0] if transpose else shape[1])
+ vector = bm.as_jax(vector)
+
+ csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector,
+ shape=shape,
+ transpose=transpose).sum(),
+ argnums=0)
+ dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()),
+ argnums=0)
+
+ r1 = csr_f1(heter_data)
+ r2 = dense_f1(dense_data)
+ rows, cols = bm.sparse.csr_to_coo(indices, indptr)
+ r2 = r2[rows, cols]
+ self.assertTrue(bm.allclose(r1, r2))
+
+ csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v,
+ shape=shape,
+ transpose=transpose).sum(),
+ argnums=0)
+ dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()),
+ argnums=0)
+ r3 = csr_f2(vector)
+ r4 = dense_f2(vector)
+ self.assertTrue(bm.allclose(r3, r4))
+
+ bm.clear_buffer_memory()
+
+
+class Test_csrmv(parameterized.TestCase):
+ def __init__(self, *args, platform='cpu', **kwargs):
+ super(Test_csrmv, self).__init__(*args, **kwargs)
+
+ print()
+ bm.set_platform(platform)
+
+ @parameterized.product(
+ homo_data=[-1., 0., 0.1, 1.],
+ shape=[(100, 200), (10, 1000), (2, 2000)],
+ )
+ def test_homo(self, shape, homo_data):
+ conn = bp.conn.FixedProb(0.1)
+
+ # matrix
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ # vector
+ rng = bm.random.RandomState(123)
+ vector = rng.random(shape[1])
+ vector = bm.as_jax(vector)
+
+ # csrmv
+ r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape)
+ r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape)
+ r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape)
+ self.assertTrue(bm.allclose(r1, r2))
+ self.assertTrue(bm.allclose(r1, r3))
+
+ heter_data = bm.ones(indices.shape).to_jax() * homo_data
+ r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
+ r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
+ r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
+ self.assertTrue(bm.allclose(r1, r4))
+ self.assertTrue(bm.allclose(r1, r5))
+ self.assertTrue(bm.allclose(r1, r6))
+
+ dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+ rdense = dense @ vector
+ self.assertTrue(bm.allclose(r1, rdense))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)]
+ )
+ def test_heter(self, shape):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ heter_data = bm.as_jax(rng.random(indices.shape))
+ vector = bm.as_jax(rng.random(shape[1]))
+
+ r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
+ r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
+ r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
+
+ dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+ r4 = dense @ vector
+ self.assertTrue(bm.allclose(r1, r2))
+ self.assertTrue(bm.allclose(r1, r3))
+ self.assertTrue(bm.allclose(r1, r4))
+
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)]
+ )
+ def test_heter_grad(self, shape):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+
+ indices, indptr = conn(*shape).require('pre2post')
+ heter_data = rng.random(indices.shape)
+ dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+ vector = rng.random(shape[1])
+
+ csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum())
+ csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum())
+ csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum())
+ dense_f1 = jax.grad(lambda a: (a @ vector).sum())
+
+ r1 = csr_f1(heter_data)
+ r2 = csr_f2(heter_data)
+ r3 = csr_f3(heter_data)
+
+ d1 = dense_f1(dense_data)
+ rows, cols = bm.sparse.csr_to_coo(indices, indptr)
+ d1 = d1[rows, cols]
+ self.assertTrue(bm.allclose(r1, r2))
+ self.assertTrue(bm.allclose(r1, r3))
+ self.assertTrue(bm.allclose(r1, d1))
+
+ # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum())
+ # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum())
+ # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum())
+ # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum())
+ # r4 = csr_f4(vector)
+ # r5 = csr_f5(vector)
+ # r6 = csr_f6(vector)
+ # d2 = dense_f2(vector)
+ # self.assertTrue(bm.allclose(r4, r5))
+ # self.assertTrue(bm.allclose(r4, r6))
+ # self.assertTrue(bm.allclose(r4, d2))
+
+ bm.clear_buffer_memory()
+
+
diff --git a/brainpy/_src/math/surrogate/__init__.py b/brainpy/_src/math/surrogate/__init__.py
index 2ad7ac54e..199eac648 100644
--- a/brainpy/_src/math/surrogate/__init__.py
+++ b/brainpy/_src/math/surrogate/__init__.py
@@ -2,5 +2,5 @@
from .base import *
-from ._one_input import *
+from ._one_input_new import *
from ._two_inputs import *
diff --git a/brainpy/_src/math/surrogate/_one_input_new.py b/brainpy/_src/math/surrogate/_one_input_new.py
new file mode 100644
index 000000000..64c7280d0
--- /dev/null
+++ b/brainpy/_src/math/surrogate/_one_input_new.py
@@ -0,0 +1,1757 @@
+# -*- coding: utf-8 -*-
+
+from typing import Union
+
+import jax
+import jax.numpy as jnp
+import jax.scipy as sci
+from jax.core import Primitive
+from jax.interpreters import batching, ad, mlir
+
+from brainpy._src.math.interoperability import as_jax
+from brainpy._src.math.ndarray import Array
+
+__all__ = [
+ 'Sigmoid',
+ 'sigmoid',
+ 'PiecewiseQuadratic',
+ 'piecewise_quadratic',
+ 'PiecewiseExp',
+ 'piecewise_exp',
+ 'SoftSign',
+ 'soft_sign',
+ 'Arctan',
+ 'arctan',
+ 'NonzeroSignLog',
+ 'nonzero_sign_log',
+ 'ERF',
+ 'erf',
+ 'PiecewiseLeakyRelu',
+ 'piecewise_leaky_relu',
+ 'SquarewaveFourierSeries',
+ 'squarewave_fourier_series',
+ 'S2NN',
+ 's2nn',
+ 'QPseudoSpike',
+ 'q_pseudo_spike',
+ 'LeakyRelu',
+ 'leaky_relu',
+ 'LogTailedRelu',
+ 'log_tailed_relu',
+ 'ReluGrad',
+ 'relu_grad',
+ 'GaussianGrad',
+ 'gaussian_grad',
+ 'InvSquareGrad',
+ 'inv_square_grad',
+ 'MultiGaussianGrad',
+ 'multi_gaussian_grad',
+ 'SlayerGrad',
+ 'slayer_grad',
+]
+
+
+def _heaviside_abstract(x, dx):
+ return [x]
+
+
+def _heaviside_imp(x, dx):
+ z = jnp.asarray(x >= 0, dtype=x.dtype)
+ return [z]
+
+
+def _heaviside_batching(args, axes):
+ return heaviside_p.bind(*args), axes
+
+
+def _heaviside_jvp(primals, tangents):
+ x, dx = primals
+ tx, tdx = tangents
+ primal_outs = heaviside_p.bind(x, dx)
+ tangent_outs = [dx * tx, ]
+ return primal_outs, tangent_outs
+
+
+heaviside_p = Primitive('heaviside_p')
+heaviside_p.multiple_results = True
+heaviside_p.def_abstract_eval(_heaviside_abstract)
+heaviside_p.def_impl(_heaviside_imp)
+batching.primitive_batchers[heaviside_p] = _heaviside_batching
+ad.primitive_jvps[heaviside_p] = _heaviside_jvp
+mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))
+
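+# Illustrative sketch (not part of the public API): the primitive behaves as a Heaviside
+# step in the forward pass, while the JVP rule above substitutes the pre-computed
+# surrogate derivative ``dx`` for the true, almost-everywhere-zero derivative. Assuming
+# ``x`` and ``dx`` are JAX arrays of the same shape:
+#
+#   spikes, = heaviside_p.bind(x, dx)                                # (x >= 0) cast to x.dtype
+#   grads = jax.grad(lambda v: heaviside_p.bind(v, dx)[0].sum())(x)  # == dx
+#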
+
+def _is_bp_array(x):
+ return isinstance(x, Array)
+
+
+def _as_jax(x):
+ return x.value if _is_bp_array(x) else x
+
+
+class Surrogate(object):
+  """The base surrogate gradient function."""
+
+ def __call__(self, x):
+ x = _as_jax(x)
+ dx = self.surrogate_grad(x)
+ return heaviside_p.bind(x, dx)[0]
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
+
+ def surrogate_fun(self, x) -> jax.Array:
+ """The surrogate function."""
+ raise NotImplementedError
+
+ def surrogate_grad(self, x) -> jax.Array:
+ """The gradient function of the surrogate function."""
+ raise NotImplementedError
+
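+# A minimal sketch of how a new surrogate can be built on this base class
+# (``TriangleGrad`` is a hypothetical example, not part of this module): only
+# ``surrogate_grad`` is required for training, while ``surrogate_fun`` documents the
+# smooth function the gradient is taken from.
+#
+#   class TriangleGrad(Surrogate):
+#     def __init__(self, alpha: float = 1.):
+#       super().__init__()
+#       self.alpha = alpha
+#
+#     def surrogate_grad(self, x):
+#       # triangular window centred on the firing threshold
+#       return jnp.maximum(1. - self.alpha * jnp.abs(x), 0.)
+#
+#   spikes = TriangleGrad()(jnp.array([-0.2, 0.3]))   # forward pass: Heaviside step
+#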
+
+class Sigmoid(Surrogate):
+ """Spike function with the sigmoid-shaped surrogate gradient.
+
+ See Also
+ --------
+ sigmoid
+
+ """
+
+ def __init__(self, alpha: float = 4.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_fun(self, x):
+    return sci.special.expit(self.alpha * x)
+
+ def surrogate_grad(self, x):
+ sgax = sci.special.expit(x * self.alpha)
+ dx = (1. - sgax) * sgax * self.alpha
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def sigmoid(
+ x: Union[jax.Array, Array],
+ alpha: float = 4.,
+):
+ r"""Spike function with the sigmoid-shaped surrogate gradient.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-2, 2, 1000)
+ >>> for alpha in [1., 2., 4.]:
+ >>> grads = bm.vector_grad(bm.surrogate.sigmoid)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+ """
+ return Sigmoid(alpha=alpha)(x)
+
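+# Usage note (illustrative; ``v`` and ``v_th`` are placeholder membrane-potential and
+# threshold arrays): the class form can be instantiated once and reused, e.g. as a neuron
+# model's spike function, while the functional form suits one-off calls:
+#
+#   spk_fun = Sigmoid(alpha=4.)
+#   spikes = spk_fun(v - v_th)        # equivalent to sigmoid(v - v_th, alpha=4.)
+#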
+
+class PiecewiseQuadratic(Surrogate):
+ """Judge spiking state with a piecewise quadratic function.
+
+ See Also
+ --------
+ piecewise_quadratic
+
+ """
+
+ def __init__(self, alpha: float = 1.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ z = jnp.where(x < -1 / self.alpha,
+ 0.,
+ jnp.where(x > 1 / self.alpha,
+ 1.,
+ (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5))
+ return z
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+    dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., -self.alpha ** 2 * jnp.abs(x) + self.alpha)
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def piecewise_quadratic(
+ x: Union[jax.Array, Array],
+ alpha: float = 1.,
+):
+ r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ g(x) =
+ \begin{cases}
+ 0, & x < -\frac{1}{\alpha} \\
+ -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
+ 1, & x > \frac{1}{\alpha} \\
+ \end{cases}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) =
+ \begin{cases}
+ 0, & |x| > \frac{1}{\alpha} \\
+ -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha}
+ \end{cases}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for alpha in [0.5, 1., 2., 4.]:
+ >>> grads = bm.vector_grad(bm.surrogate.piecewise_quadratic)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446.
+ .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+ .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
+ .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
+ .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
+ """
+ return PiecewiseQuadratic(alpha=alpha)(x)
+
+
+class PiecewiseExp(Surrogate):
+ """Judge spiking state with a piecewise exponential function.
+
+ See Also
+ --------
+ piecewise_exp
+ """
+
+ def __init__(self, alpha: float = 1.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
+ return dx
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2)
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def piecewise_exp(
+ x: Union[jax.Array, Array],
+ alpha: float = 1.,
+):
+ r"""Judge spiking state with a piecewise exponential function [1]_.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ \frac{1}{2}e^{\alpha x}, & x < 0 \\
+ 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
+ \end{cases}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for alpha in [0.5, 1., 2., 4.]:
+ >>> grads = bm.vector_grad(bm.surrogate.piecewise_exp)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
+ """
+ return PiecewiseExp(alpha=alpha)(x)
+
+
+class SoftSign(Surrogate):
+ """Judge spiking state with a soft sign function.
+
+ See Also
+ --------
+ soft_sign
+ """
+
+ def __init__(self, alpha=1.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
+ return dx
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def soft_sign(
+ x: Union[jax.Array, Array],
+ alpha: float = 1.,
+):
+ r"""Judge spiking state with a soft sign function.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
+ = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for alpha in [0.5, 1., 2., 4.]:
+ >>> grads = bm.vector_grad(bm.surrogate.soft_sign)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ """
+ return SoftSign(alpha=alpha)(x)
+
+
+class Arctan(Surrogate):
+ """Judge spiking state with an arctan function.
+
+ See Also
+ --------
+ arctan
+ """
+
+ def __init__(self, alpha=1.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
+ return dx
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+    return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def arctan(
+ x: Union[jax.Array, Array],
+ alpha: float = 1.,
+):
+ r"""Judge spiking state with an arctan function.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for alpha in [0.5, 1., 2., 4.]:
+ >>> grads = bm.vector_grad(bm.surrogate.arctan)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ """
+ return Arctan(alpha=alpha)(x)
+
+
+class NonzeroSignLog(Surrogate):
+ """Judge spiking state with a nonzero sign log function.
+
+ See Also
+ --------
+ nonzero_sign_log
+ """
+
+ def __init__(self, alpha=1.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = 1. / (1 / self.alpha + jnp.abs(x))
+ return dx
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def nonzero_sign_log(
+ x: Union[jax.Array, Array],
+ alpha: float = 1.,
+):
+ r"""Judge spiking state with a nonzero sign log function.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
+
+ where
+
+ .. math::
+
+ \begin{split}\mathrm{NonzeroSign}(x) =
+ \begin{cases}
+ 1, & x \geq 0 \\
+ -1, & x < 0 \\
+ \end{cases}\end{split}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}
+
+ This surrogate function has the advantage of low computation cost during the backward.
+
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for alpha in [0.5, 1., 2., 4.]:
+ >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ """
+ return NonzeroSignLog(alpha=alpha)(x)
+
+
+class ERF(Surrogate):
+ """Judge spiking state with an erf function.
+
+ See Also
+ --------
+ erf
+ """
+
+ def __init__(self, alpha=1.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
+ return dx
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+    return sci.special.erfc(-self.alpha * x) * 0.5
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def erf(
+ x: Union[jax.Array, Array],
+ alpha: float = 1.,
+):
+ r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ \begin{split}
+ g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
+ &= \frac{1}{2} \text{erfc}(-\alpha x) \\
+ &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
+ \end{split}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for alpha in [0.5, 1., 2., 4.]:
+     >>> grads = bm.vector_grad(bm.surrogate.erf)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125.
+ .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+ .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
+
+ """
+ return ERF(alpha=alpha)(x)
+
+
+class PiecewiseLeakyRelu(Surrogate):
+ """Judge spiking state with a piecewise leaky relu function.
+
+ See Also
+ --------
+ piecewise_leaky_relu
+ """
+
+ def __init__(self, c=0.01, w=1.):
+ super().__init__()
+ self.c = c
+ self.w = w
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ z = jnp.where(x < -self.w,
+ self.c * x + self.c * self.w,
+ jnp.where(x > self.w,
+ self.c * x - self.c * self.w + 1,
+ 0.5 * x / self.w + 0.5))
+ return z
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
+
+
+def piecewise_leaky_relu(
+ x: Union[jax.Array, Array],
+ c: float = 0.01,
+ w: float = 1.,
+):
+ r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ \begin{split}g(x) =
+ \begin{cases}
+ cx + cw, & x < -w \\
+ \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
+ cx - cw + 1, & x > w \\
+ \end{cases}\end{split}
+
+ Backward function:
+
+ .. math::
+
+ \begin{split}g'(x) =
+ \begin{cases}
+ \frac{1}{w}, & |x| \leq w \\
+ c, & |x| > w
+ \end{cases}\end{split}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for c in [0.01, 0.05, 0.1]:
+ >>> for w in [1., 2.]:
+ >>> grads1 = bm.vector_grad(bm.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
+     >>>     plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'c={c}, w={w}')
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ c: float
+ When :math:`|x| > w` the gradient is `c`.
+ w: float
+ When :math:`|x| <= w` the gradient is `1 / w`.
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
+ .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+ .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
+ .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
+ .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
+ .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
+ .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525.
+ .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
+
+ """
+ return PiecewiseLeakyRelu(c=c, w=w)(x)
+
+
+class SquarewaveFourierSeries(Surrogate):
+ """Judge spiking state with a squarewave fourier series.
+
+ See Also
+ --------
+ squarewave_fourier_series
+ """
+
+ def __init__(self, n=2, t_period=8.):
+ super().__init__()
+ self.n = n
+ self.t_period = t_period
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ w = jnp.pi * 2. / self.t_period
+ dx = jnp.cos(w * x)
+ for i in range(2, self.n):
+ dx += jnp.cos((2 * i - 1.) * w * x)
+ dx *= 4. / self.t_period
+ return dx
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ w = jnp.pi * 2. / self.t_period
+ ret = jnp.sin(w * x)
+ for i in range(2, self.n):
+ c = (2 * i - 1.)
+ ret += jnp.sin(c * w * x) / c
+ z = 0.5 + 2. / jnp.pi * ret
+ return z
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
+
+
+def squarewave_fourier_series(
+ x: Union[jax.Array, Array],
+ n: int = 2,
+ t_period: float = 8.,
+):
+ r"""Judge spiking state with a squarewave fourier series.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+      g(x) = 0.5 + \frac{2}{\pi}\sum_{i=1}^n \frac{\sin\left((2i-1)\, 2\pi x / T\right)}{2i-1}
+
+ Backward function:
+
+ .. math::
+
+      g'(x) = \sum_{i=1}^n \frac{4\cos\left((2i-1)\, 2\pi x / T\right)}{T}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for n in [2, 4, 8]:
+ >>> f = bm.surrogate.SquarewaveFourierSeries(n=n)
+ >>> grads1 = bm.vector_grad(f)(xs)
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'n={n}')
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+  n: int
+    The number of terms used in the Fourier series approximation.
+  t_period: float
+    The period :math:`T` of the square wave.
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ """
+
+ return SquarewaveFourierSeries(n=n, t_period=t_period)(x)
+
+
+class S2NN(Surrogate):
+ """Judge spiking state with the S2NN surrogate spiking function.
+
+ See Also
+ --------
+ s2nn
+ """
+
+ def __init__(self, alpha=4., beta=1., epsilon=1e-8):
+ super().__init__()
+ self.alpha = alpha
+ self.beta = beta
+ self.epsilon = epsilon
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ z = jnp.where(x < 0.,
+ sci.special.expit(x * self.alpha),
+ self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5)
+ return z
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ sg = sci.special.expit(self.alpha * x)
+ dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
+
+
+def s2nn(
+ x: Union[jax.Array, Array],
+ alpha: float = 4.,
+ beta: float = 1.,
+ epsilon: float = 1e-8,
+):
+ r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ \begin{split}g(x) = \begin{cases}
+ \mathrm{sigmoid} (\alpha x), x < 0 \\
+ \beta \ln(|x + 1|) + 0.5, x \ge 0
+ \end{cases}\end{split}
+
+ Backward function:
+
+ .. math::
+
+ \begin{split}g'(x) = \begin{cases}
+ \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\
+ \frac{\beta}{(x + 1)}, x \ge 0
+ \end{cases}\end{split}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 4., 1.)
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=4, \beta=1$')
+ >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 8., 2.)
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=8, \beta=2$')
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ The param that controls the gradient when ``x < 0``.
+ beta: float
+ The param that controls the gradient when ``x >= 0``
+ epsilon: float
+ Avoid nan
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag.
+
+ """
+ return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)
+
+
+class QPseudoSpike(Surrogate):
+ """Judge spiking state with the q-PseudoSpike surrogate function.
+
+ See Also
+ --------
+ q_pseudo_spike
+ """
+
+ def __init__(self, alpha=2.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+    dx = jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), -self.alpha)
+ return dx
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ z = jnp.where(x < 0.,
+ 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
+ 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
+ return z
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def q_pseudo_spike(
+ x: Union[jax.Array, Array],
+ alpha: float = 2.,
+):
+ r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ \begin{split}g(x) =
+ \begin{cases}
+ \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
+ 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
+ \end{cases}\end{split}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> for alpha in [0.5, 1., 2., 4.]:
+ >>> grads = bm.vector_grad(bm.surrogate.q_pseudo_spike)(xs, alpha)
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ The parameter to control tail fatness of gradient.
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag.
+ """
+ return QPseudoSpike(alpha=alpha)(x)
+
+
+class LeakyRelu(Surrogate):
+ """Judge spiking state with the Leaky ReLU function.
+
+ See Also
+ --------
+ leaky_relu
+ """
+
+ def __init__(self, alpha=0.1, beta=1.):
+ super().__init__()
+ self.alpha = alpha
+ self.beta = beta
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ return jnp.where(x < 0., self.alpha * x, self.beta * x)
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = jnp.where(x < 0., self.alpha, self.beta)
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
+
+
+def leaky_relu(
+ x: Union[jax.Array, Array],
+ alpha: float = 0.1,
+ beta: float = 1.,
+):
+ r"""Judge spiking state with the Leaky ReLU function.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ \begin{split}g(x) =
+ \begin{cases}
+ \beta \cdot x, & x \geq 0 \\
+ \alpha \cdot x, & x < 0 \\
+ \end{cases}\end{split}
+
+ Backward function:
+
+ .. math::
+
+ \begin{split}g'(x) =
+ \begin{cases}
+ \beta, & x \geq 0 \\
+ \alpha, & x < 0 \\
+ \end{cases}\end{split}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.)
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$')
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ The parameter to control the gradient when :math:`x < 0`.
+ beta: float
+ The parameter to control the gradient when :math:`x >= 0`.
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+ """
+ return LeakyRelu(alpha=alpha, beta=beta)(x)
+
+
+class LogTailedRelu(Surrogate):
+ """Judge spiking state with the Log-tailed ReLU function.
+
+ See Also
+ --------
+ log_tailed_relu
+ """
+
+ def __init__(self, alpha=0.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_fun(self, x):
+ x = as_jax(x)
+ z = jnp.where(x > 1,
+ jnp.log(x),
+ jnp.where(x > 0,
+ x,
+ self.alpha * x))
+ return z
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = jnp.where(x > 1,
+ 1 / x,
+ jnp.where(x > 0,
+ 1.,
+ self.alpha))
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def log_tailed_relu(
+ x: Union[jax.Array, Array],
+ alpha: float = 0.,
+):
+ r"""Judge spiking state with the Log-tailed ReLU function [1]_.
+
+  The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+  The original function:
+
+ .. math::
+
+ \begin{split}g(x) =
+ \begin{cases}
+ \alpha x, & x \leq 0 \\
+         x, & 0 < x \leq 1 \\
+         \log(x), & x > 1 \\
+ \end{cases}\end{split}
+
+ Backward function:
+
+ .. math::
+
+ \begin{split}g'(x) =
+ \begin{cases}
+ \alpha, & x \leq 0 \\
+         1, & 0 < x \leq 1 \\
+         \frac{1}{x}, & x > 1 \\
+ \end{cases}\end{split}
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+     >>> grads = bm.vector_grad(bm.surrogate.log_tailed_relu)(xs, 0.)
+     >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.$')
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ The parameter to control the gradient.
+
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
+ """
+ return LogTailedRelu(alpha=alpha)(x)
+
+
+class ReluGrad(Surrogate):
+ """Judge spiking state with the ReLU gradient function.
+
+ See Also
+ --------
+ relu_grad
+ """
+
+ def __init__(self, alpha=0.3, width=1.):
+ super().__init__()
+ self.alpha = alpha
+ self.width = width
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'
+
+
+def relu_grad(
+ x: Union[jax.Array, Array],
+ alpha: float = 0.3,
+ width: float = 1.,
+):
+ r"""Spike function with the ReLU gradient function [1]_.
+
+ The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|))
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> for s in [0.5, 1.]:
+ >>> for w in [1, 2.]:
+ >>> grads = bm.vector_grad(bm.surrogate.relu_grad)(xs, s, w)
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + f'{s}, width={w}')
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ The parameter to control the gradient.
+ width: float
+ The parameter to control the width of the gradient.
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019).
+ """
+ return ReluGrad(alpha=alpha, width=width)(x)
+
+
+class GaussianGrad(Surrogate):
+ """Judge spiking state with the Gaussian gradient function.
+
+ See Also
+ --------
+ gaussian_grad
+ """
+
+ def __init__(self, sigma=0.5, alpha=0.5):
+ super().__init__()
+ self.sigma = sigma
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+    dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
+ return self.alpha * dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
+
+
+def gaussian_grad(
+ x: Union[jax.Array, Array],
+ sigma: float = 0.5,
+ alpha: float = 0.5,
+):
+ r"""Spike function with the Gaussian gradient function [1]_.
+
+ The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \alpha * \text{gaussian}(x, 0., \sigma)
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> for s in [0.5, 1., 2.]:
+ >>> grads = bm.vector_grad(bm.surrogate.gaussian_grad)(xs, s, 0.5)
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.5, \sigma=$' + str(s))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ sigma: float
+ The parameter to control the variance of gaussian distribution.
+ alpha: float
+ The parameter to control the scale of the gradient.
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
+ """
+ return GaussianGrad(sigma=sigma, alpha=alpha)(x)
+
+
+class MultiGaussianGrad(Surrogate):
+ """Judge spiking state with the multi-Gaussian gradient function.
+
+ See Also
+ --------
+ multi_gaussian_grad
+ """
+
+ def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
+ super().__init__()
+ self.h = h
+ self.s = s
+ self.sigma = sigma
+ self.scale = scale
+
+ def surrogate_grad(self, x):
+ x = as_jax(x)
+ g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
+ g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
+ ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
+ g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
+ ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
+ dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h
+ return self.scale * dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
+
+
+def multi_gaussian_grad(
+ x: Union[jax.Array, Array],
+ h: float = 0.15,
+ s: float = 6.0,
+ sigma: float = 0.5,
+ scale: float = 0.5,
+):
+ r"""Spike function with the multi-Gaussian gradient function [1]_.
+
+ The forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+ Backward function:
+
+ .. math::
+
+ \begin{array}{l}
+ g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2})
+ -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})-
+ h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2})
+ \end{array}
+
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> grads = bm.vector_grad(bm.surrogate.multi_gaussian_grad)(xs)
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads))
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ h: float
+ The hyper-parameters of approximate function
+ s: float
+ The hyper-parameters of approximate function
+ sigma: float
+ The gaussian sigma.
+ scale: float
+ The gradient scale.
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
+ """
+ return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x)
+
+
+class InvSquareGrad(Surrogate):
+ """Judge spiking state with the inverse-square surrogate gradient function.
+
+ See Also
+ --------
+ inv_square_grad
+ """
+
+ def __init__(self, alpha=100.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def inv_square_grad(
+ x: Union[jax.Array, Array],
+ alpha: float = 100.
+):
+ r"""Spike function with the inverse-square surrogate gradient.
+
+ Forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2}
+
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> xs = bm.linspace(-1, 1, 1000)
+ >>> for alpha in [1., 10., 100.]:
+ >>> grads = bm.vector_grad(bm.surrogate.inv_square_grad)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+ """
+ return InvSquareGrad(alpha=alpha)(x)
+
+
+class SlayerGrad(Surrogate):
+ """Judge spiking state with the slayer surrogate gradient function.
+
+ See Also
+ --------
+ slayer_grad
+ """
+
+ def __init__(self, alpha=1.):
+ super().__init__()
+ self.alpha = alpha
+
+ def surrogate_grad(self, x):
+ dx = jnp.exp(-self.alpha * jnp.abs(x))
+ return dx
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+def slayer_grad(
+ x: Union[jax.Array, Array],
+ alpha: float = 1.
+):
+  r"""Spike function with the slayer surrogate gradient function [1]_.
+
+ Forward function:
+
+ .. math::
+
+ g(x) = \begin{cases}
+ 1, & x \geq 0 \\
+ 0, & x < 0 \\
+ \end{cases}
+
+ Backward function:
+
+ .. math::
+
+ g'(x) = \exp(-\alpha |x|)
+
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>> import matplotlib.pyplot as plt
+ >>> bp.visualize.get_figure(1, 1, 4, 6)
+ >>> xs = bm.linspace(-3, 3, 1000)
+ >>> for alpha in [0.5, 1., 2., 4.]:
+ >>> grads = bm.vector_grad(bm.surrogate.slayer_grad)(xs, alpha)
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ x: jax.Array, Array
+ The input data.
+ alpha: float
+ Parameter to control smoothness of gradient
+
+ Returns
+ -------
+ out: jax.Array
+ The spiking state.
+
+ References
+ ----------
+ .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018).
+ """
+ return SlayerGrad(alpha=alpha)(x)
diff --git a/brainpy/_src/math/tests/test_defaults.py b/brainpy/_src/math/tests/test_defaults.py
new file mode 100644
index 000000000..9076829b7
--- /dev/null
+++ b/brainpy/_src/math/tests/test_defaults.py
@@ -0,0 +1,36 @@
+import unittest
+
+import brainpy.math as bm
+
+
+class TestDefaults(unittest.TestCase):
+ def test_dt(self):
+ with bm.environment(dt=1.0):
+ self.assertEqual(bm.dt, 1.0)
+ self.assertEqual(bm.get_dt(), 1.0)
+
+ def test_bool(self):
+ with bm.environment(bool_=bm.int32):
+ self.assertTrue(bm.bool_ == bm.int32)
+ self.assertTrue(bm.get_bool() == bm.int32)
+
+ def test_int(self):
+ with bm.environment(int_=bm.int32):
+ self.assertTrue(bm.int == bm.int32)
+ self.assertTrue(bm.get_int() == bm.int32)
+
+ def test_float(self):
+ with bm.environment(float_=bm.float32):
+ self.assertTrue(bm.float_ == bm.float32)
+ self.assertTrue(bm.get_float() == bm.float32)
+
+ def test_complex(self):
+ with bm.environment(complex_=bm.complex64):
+ self.assertTrue(bm.complex_ == bm.complex64)
+ self.assertTrue(bm.get_complex() == bm.complex64)
+
+ def test_mode(self):
+ mode = bm.TrainingMode()
+ with bm.environment(mode=mode):
+ self.assertTrue(bm.mode == mode)
+ self.assertTrue(bm.get_mode() == mode)
diff --git a/brainpy/_src/math/tests/test_einops.py b/brainpy/_src/math/tests/test_einops.py
new file mode 100644
index 000000000..2f018d973
--- /dev/null
+++ b/brainpy/_src/math/tests/test_einops.py
@@ -0,0 +1,331 @@
+import numpy
+import pytest
+
+import brainpy.math as bm
+from brainpy._src.math.einops import ein_rearrange, ein_reduce, ein_repeat, _enumerate_directions
+from brainpy._src.math.einops_parsing import EinopsError
+
+REDUCTIONS = ("min", "max", "sum", "mean", "prod")
+
+identity_patterns = [
+ "...->...",
+ "a b c d e-> a b c d e",
+ "a b c d e ...-> ... a b c d e",
+ "a b c d e ...-> a ... b c d e",
+ "... a b c d e -> ... a b c d e",
+ "a ... e-> a ... e",
+ "a ... -> a ... ",
+ "a ... c d e -> a (...) c d e",
+]
+
+equivalent_rearrange_patterns = [
+ ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "),
+ ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"),
+ ("a b c d e -> a b c d e", "... -> ... "),
+ ("a b c d e -> (a b c d e)", "... -> (...)"),
+ ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"),
+ ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"),
+]
+
+equivalent_reduction_patterns = [
+ ("a b c d e -> ", " ... -> "),
+ ("a b c d e -> (e a)", "a ... e -> (e a)"),
+ ("a b c d e -> d (a e)", " a b c d e ... -> d (a e) "),
+ ("a b c d e -> (a b)", " ... c d e -> (...) "),
+]
+
+
+def test_collapsed_ellipsis_errors_out():
+ x = numpy.zeros([1, 1, 1, 1, 1])
+ ein_rearrange(x, "a b c d ... -> a b c ... d")
+ with pytest.raises(EinopsError):
+ ein_rearrange(x, "a b c d (...) -> a b c ... d")
+
+ ein_rearrange(x, "... -> (...)")
+ with pytest.raises(EinopsError):
+ ein_rearrange(x, "(...) -> (...)")
+
+
+def test_ellipsis_ops_numpy():
+ x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
+ for pattern in identity_patterns:
+ assert numpy.array_equal(x, ein_rearrange(x, pattern)), pattern
+
+ for pattern1, pattern2 in equivalent_rearrange_patterns:
+ assert numpy.array_equal(ein_rearrange(x, pattern1), ein_rearrange(x, pattern2))
+
+ for reduction in ["min", "max", "sum"]:
+ for pattern1, pattern2 in equivalent_reduction_patterns:
+ assert numpy.array_equal(ein_reduce(x, pattern1, reduction=reduction),
+ ein_reduce(x, pattern2, reduction=reduction))
+
+ # now just check coincidence with numpy
+ all_rearrange_patterns = [*identity_patterns]
+ for pattern_pairs in equivalent_rearrange_patterns:
+ all_rearrange_patterns.extend(pattern_pairs)
+
+
+def test_rearrange_consistency_numpy():
+ shape = [1, 2, 3, 5, 7, 11]
+ x = numpy.arange(numpy.prod(shape)).reshape(shape)
+ for pattern in [
+ "a b c d e f -> a b c d e f",
+ "b a c d e f -> a b d e f c",
+ "a b c d e f -> f e d c b a",
+ "a b c d e f -> (f e) d (c b a)",
+ "a b c d e f -> (f e d c b a)",
+ ]:
+ result = ein_rearrange(x, pattern)
+ assert len(numpy.setdiff1d(x, result)) == 0
+
+ result = ein_rearrange(x, "a b c d e f -> a (b) (c d e) f")
+ assert numpy.array_equal(x.flatten(), result.flatten())
+
+ result = ein_rearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11")
+ assert numpy.array_equal(x, result)
+
+ result1 = ein_rearrange(x, "a b c d e f -> f e d c b a")
+ result2 = ein_rearrange(x, "f e d c b a -> a b c d e f")
+ assert numpy.array_equal(result1, result2)
+
+ result = ein_rearrange(ein_rearrange(x, "a b c d e f -> (f d) c (e b) a"), "(f d) c (e b) a -> a b c d e f", b=2, d=5)
+ assert numpy.array_equal(x, result)
+
+ sizes = dict(zip("abcdef", shape))
+ temp = ein_rearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes)
+ result = ein_rearrange(temp, "(f d) c (e b) a -> a b c d e f", **sizes)
+ assert numpy.array_equal(x, result)
+
+ x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4])
+ result = ein_rearrange(x2, "a b c -> b c a")
+ assert x2[1, 2, 3] == result[2, 3, 1]
+ assert x2[0, 1, 2] == result[1, 2, 0]
+
+
+def test_rearrange_permutations_numpy():
+ # tests random permutation of axes against two independent numpy ways
+ for n_axes in range(1, 10):
+ input = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
+ permutation = numpy.random.permutation(n_axes)
+ left_expression = " ".join("i" + str(axis) for axis in range(n_axes))
+ right_expression = " ".join("i" + str(axis) for axis in permutation)
+ expression = left_expression + " -> " + right_expression
+ result = ein_rearrange(input, expression)
+
+ for pick in numpy.random.randint(0, 2, [10, n_axes]):
+ assert input[tuple(pick)] == result[tuple(pick[permutation])]
+
+ for n_axes in range(1, 10):
+ input = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
+ permutation = numpy.random.permutation(n_axes)
+ left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1])
+ right_expression = " ".join("i" + str(axis) for axis in permutation[::-1])
+ expression = left_expression + " -> " + right_expression
+ result = ein_rearrange(input, expression)
+ assert result.shape == input.shape
+ expected_result = numpy.zeros_like(input)
+ for original_axis, result_axis in enumerate(permutation):
+ expected_result |= ((input >> original_axis) & 1) << result_axis
+
+ assert numpy.array_equal(result, expected_result)
+
+
+def test_reduction_imperatives():
+ for reduction in REDUCTIONS:
+ # slight redundancy for simpler order - numpy version is evaluated multiple times
+ input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6])
+ if reduction in ["mean", "prod"]:
+ input = input / input.astype("float64").mean()
+ test_cases = [
+ ["a b c d e -> ", {}, getattr(input, reduction)()],
+ ["a ... -> ", {}, getattr(input, reduction)()],
+ ["(a1 a2) ... (e1 e2) -> ", dict(a1=1, e2=2), getattr(input, reduction)()],
+ [
+ "a b c d e -> (e c) a",
+ {},
+ getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
+ ],
+ [
+ "a ... c d e -> (e c) a",
+ {},
+ getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
+ ],
+ [
+ "a b c d e ... -> (e c) a",
+ {},
+ getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
+ ],
+ ["a b c d e -> (e c a)", {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])],
+ ["(a a2) ... -> (a2 a) ...", dict(a2=1), input],
+ ]
+ for pattern, axes_lengths, expected_result in test_cases:
+ result = ein_reduce(bm.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths)
+ result = bm.as_numpy(result)
+ print(reduction, pattern, expected_result, result)
+ assert numpy.allclose(result, expected_result), f"Failed at {pattern}"
+
+
+def test_enumerating_directions():
+ for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]:
+ x = numpy.arange(numpy.prod(shape)).reshape(shape)
+ axes1 = _enumerate_directions(x)
+ axes2 = _enumerate_directions(bm.from_numpy(x))
+ assert len(axes1) == len(axes2) == len(shape)
+ for ax1, ax2 in zip(axes1, axes2):
+ ax2 = bm.as_numpy(ax2)
+ assert ax1.shape == ax2.shape
+ assert numpy.allclose(ax1, ax2)
+
+
+def test_concatenations_and_stacking():
+ for n_arrays in [1, 2, 5]:
+ shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
+ for shape in shapes:
+ arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)]
+ arrays2 = [bm.from_numpy(array) for array in arrays1]
+ result0 = numpy.asarray(arrays1)
+ result1 = ein_rearrange(arrays1, "...->...")
+ result2 = ein_rearrange(arrays2, "...->...")
+ assert numpy.array_equal(result0, result1)
+ assert numpy.array_equal(result1, bm.as_numpy(result2))
+
+ result1 = ein_rearrange(arrays1, "b ... -> ... b")
+ result2 = ein_rearrange(arrays2, "b ... -> ... b")
+ assert numpy.array_equal(result1, bm.as_numpy(result2))
+
+
+def test_gradients_imperatives():
+ # lazy - just checking reductions
+ for reduction in REDUCTIONS:
+ if reduction in ("any", "all"):
+ continue # non-differentiable ops
+ x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype("float32")
+ y0 = bm.from_numpy(x)
+ if not hasattr(y0, "grad"):
+ continue
+
+ y1 = ein_reduce(y0, "a b c -> c a", reduction=reduction)
+ y2 = ein_reduce(y1, "c a -> a c", reduction=reduction)
+ y3 = ein_reduce(y2, "a (c1 c2) -> a", reduction=reduction, c1=2)
+ y4 = ein_reduce(y3, "... -> ", reduction=reduction)
+
+ y4.backward()
+ grad = bm.as_numpy(y0.grad)
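+ # smoke test only: checks that backward() runs through the chained reductions; gradient values are not asserted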
+
+
+def test_tiling_imperatives():
+ input = numpy.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5])
+ test_cases = [
+ (1, 1, 1, 1, 1),
+ (1, 2, 1, 3, 1),
+ (3, 1, 1, 4, 1),
+ ]
+ for repeats in test_cases:
+ expected = numpy.tile(input, repeats)
+ converted = bm.from_numpy(input)
+ repeated = bm.tile(converted, repeats)
+ result = bm.as_numpy(repeated)
+ assert numpy.array_equal(result, expected)
+
+
+repeat_test_cases = [
+ # all assume that input has shape [2, 3, 5]
+ ("a b c -> c a b", dict()),
+ ("a b c -> (c copy a b)", dict(copy=2, a=2, b=3, c=5)),
+ ("a b c -> (a copy) b c ", dict(copy=1)),
+ ("a b c -> (c a) (copy1 b copy2)", dict(a=2, copy1=1, copy2=2)),
+ ("a ... -> a ... copy", dict(copy=4)),
+ ("... c -> ... (copy1 c copy2)", dict(copy1=1, copy2=2)),
+ ("... -> ... ", dict()),
+ (" ... -> copy1 ... copy2 ", dict(copy1=2, copy2=3)),
+ ("a b c -> copy1 a copy2 b c () ", dict(copy1=2, copy2=1)),
+]
+
+
+def check_reversion(x, repeat_pattern, **sizes):
+ """Checks repeat pattern by running reduction"""
+ left, right = repeat_pattern.split("->")
+ reduce_pattern = right + "->" + left
+ repeated = ein_repeat(x, repeat_pattern, **sizes)
+ reduced_min = ein_reduce(repeated, reduce_pattern, reduction="min", **sizes)
+ reduced_max = ein_reduce(repeated, reduce_pattern, reduction="max", **sizes)
+ assert numpy.array_equal(x, reduced_min)
+ assert numpy.array_equal(x, reduced_max)
+
+
+def test_repeat_numpy():
+ # check repeat against reduce: a repeat pattern is considered correct if reducing it back with both min and max recovers the input
+ x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
+ x1 = ein_repeat(x, "a b c -> copy a b c ", copy=1)
+ assert numpy.array_equal(x[None], x1)
+ for pattern, axis_dimensions in repeat_test_cases:
+ check_reversion(x, pattern, **axis_dimensions)
+
+
+test_cases_repeat_anonymous = [
+ # all assume that input has shape [1, 2, 4, 6]
+ ("a b c d -> c a d b", dict()),
+ ("a b c d -> (c 2 d a b)", dict(a=1, c=4, d=6)),
+ ("1 b c d -> (d copy 1) 3 b c ", dict(copy=3)),
+ ("1 ... -> 3 ... ", dict()),
+ ("() ... d -> 1 (copy1 d copy2) ... ", dict(copy1=2, copy2=3)),
+ ("1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)", dict()),
+]
+
+
+def test_anonymous_axes():
+ x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
+ for pattern, axis_dimensions in test_cases_repeat_anonymous:
+ check_reversion(x, pattern, **axis_dimensions)
+
+
+def test_list_inputs():
+ x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
+
+ assert numpy.array_equal(
+ ein_rearrange(list(x), "... -> (...)"),
+ ein_rearrange(x, "... -> (...)"),
+ )
+ assert numpy.array_equal(
+ ein_reduce(list(x), "a ... e -> (...)", "min"),
+ ein_reduce(x, "a ... e -> (...)", "min"),
+ )
+ assert numpy.array_equal(
+ ein_repeat(list(x), "... -> b (...)", b=3),
+ ein_repeat(x, "... -> b (...)", b=3),
+ )
+
+
+def bit_count(x):
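+ # counts set bits among the low 20 bits (more than enough for the 6-bit indices used below)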
+ return sum((x >> i) & 1 for i in range(20))
+
+
+def test_reduction_imperatives_booleans():
+ """Checks that any/all reduction works in all frameworks"""
+ x_np = numpy.asarray([(bit_count(x) % 2) == 0 for x in range(2 ** 6)]).reshape([2] * 6)
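+ # flipping any single bit flips the parity, so along every axis each pair contains one True
+ # and one False; hence `any` and `all` must disagree, as asserted below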
+
+ for axis in range(6):
+ expected_result_any = numpy.any(x_np, axis=axis, keepdims=True)
+ expected_result_all = numpy.all(x_np, axis=axis, keepdims=True)
+ assert not numpy.array_equal(expected_result_any, expected_result_all)
+
+ axes = list("abcdef")
+ axes_in = list(axes)
+ axes_out = list(axes)
+ axes_out[axis] = "1"
+ pattern = (" ".join(axes_in)) + " -> " + (" ".join(axes_out))
+
+ res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any")
+ res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all")
+
+ assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any))
+ assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all))
+
+ # expected result: any/all
+ expected_result_any = numpy.any(x_np, axis=(0, 1), keepdims=True)
+ expected_result_all = numpy.all(x_np, axis=(0, 1), keepdims=True)
+ pattern = "a b ... -> 1 1 ..."
+ res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any")
+ res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all")
+ assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any))
+ assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all))
diff --git a/brainpy/_src/math/tests/test_einops_parsing.py b/brainpy/_src/math/tests/test_einops_parsing.py
new file mode 100644
index 000000000..069c7bbac
--- /dev/null
+++ b/brainpy/_src/math/tests/test_einops_parsing.py
@@ -0,0 +1,111 @@
+import pytest
+
+from brainpy._src.math.einops_parsing import EinopsError, ParsedExpression, AnonymousAxis, _ellipsis
+
+
+class AnonymousAxisPlaceholder:
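+ # equality is value-based, so a placeholder matches any AnonymousAxis carrying the same value
+ # (AnonymousAxis instances themselves compare by identity, as test_anonymous_axes below shows)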
+ def __init__(self, value: int):
+ self.value = value
+ assert isinstance(self.value, int)
+
+ def __eq__(self, other):
+ return isinstance(other, AnonymousAxis) and self.value == other.value
+
+
+def test_anonymous_axes():
+ a, b = AnonymousAxis('2'), AnonymousAxis('2')
+ assert a != b
+ c, d = AnonymousAxisPlaceholder(2), AnonymousAxisPlaceholder(3)
+ assert a == c and b == c
+ assert a != d and b != d
+ assert [a, 2, b] == [c, 2, c]
+
+
+def test_elementary_axis_name():
+ for name in ['a', 'b', 'h', 'dx', 'h1', 'zz', 'i9123', 'somelongname',
+ 'Alex', 'camelCase', 'u_n_d_e_r_score', 'unreasonablyLongAxisName']:
+ assert ParsedExpression.check_axis_name(name)
+
+ for name in ['', '2b', '12', '_startWithUnderscore', 'endWithUnderscore_', '_', '...', _ellipsis]:
+ assert not ParsedExpression.check_axis_name(name)
+
+
+def test_invalid_expressions():
+ # double ellipsis should raise an error
+ ParsedExpression('... a b c d')
+ with pytest.raises(EinopsError):
+ ParsedExpression('... a b c d ...')
+ with pytest.raises(EinopsError):
+ ParsedExpression('... a b c (d ...)')
+ with pytest.raises(EinopsError):
+ ParsedExpression('(... a) b c (d ...)')
+
+ # double/missing/enclosed parenthesis
+ ParsedExpression('(a) b c (d ...)')
+ with pytest.raises(EinopsError):
+ ParsedExpression('(a)) b c (d ...)')
+ with pytest.raises(EinopsError):
+ ParsedExpression('(a b c (d ...)')
+ with pytest.raises(EinopsError):
+ ParsedExpression('(a) (()) b c (d ...)')
+ with pytest.raises(EinopsError):
+ ParsedExpression('(a) ((b c) (d ...))')
+
+ # invalid identifiers
+ ParsedExpression('camelCase under_scored cApiTaLs ß ...')
+ with pytest.raises(EinopsError):
+ ParsedExpression('1a')
+ with pytest.raises(EinopsError):
+ ParsedExpression('_pre')
+ with pytest.raises(EinopsError):
+ ParsedExpression('...pre')
+ with pytest.raises(EinopsError):
+ ParsedExpression('pre...')
+
+
+def test_parse_expression():
+ parsed = ParsedExpression('a1 b1 c1 d1')
+ assert parsed.identifiers == {'a1', 'b1', 'c1', 'd1'}
+ assert parsed.composition == [['a1'], ['b1'], ['c1'], ['d1']]
+ assert not parsed.has_non_unitary_anonymous_axes
+ assert not parsed.has_ellipsis
+
+ parsed = ParsedExpression('() () () ()')
+ assert parsed.identifiers == set()
+ assert parsed.composition == [[], [], [], []]
+ assert not parsed.has_non_unitary_anonymous_axes
+ assert not parsed.has_ellipsis
+
+ parsed = ParsedExpression('1 1 1 ()')
+ assert parsed.identifiers == set()
+ assert parsed.composition == [[], [], [], []]
+ assert not parsed.has_non_unitary_anonymous_axes
+ assert not parsed.has_ellipsis
+
+ aap = AnonymousAxisPlaceholder
+
+ parsed = ParsedExpression('5 (3 4)')
+ assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5}
+ assert parsed.composition == [[aap(5)], [aap(3), aap(4)]]
+ assert parsed.has_non_unitary_anonymous_axes
+ assert not parsed.has_ellipsis
+
+ parsed = ParsedExpression('5 1 (1 4) 1')
+ assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5}
+ assert parsed.composition == [[aap(5)], [], [aap(4)], []]
+
+ parsed = ParsedExpression('name1 ... a1 12 (name2 14)')
+ assert len(parsed.identifiers) == 6
+ assert len(parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'})) == 2
+ assert parsed.composition == [['name1'], _ellipsis, ['a1'], [aap(12)], ['name2', aap(14)]]
+ assert parsed.has_non_unitary_anonymous_axes
+ assert parsed.has_ellipsis
+ assert not parsed.has_ellipsis_parenthesized
+
+ parsed = ParsedExpression('(name1 ... a1 12) name2 14')
+ assert len(parsed.identifiers) == 6
+ assert len(parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'})) == 2
+ assert parsed.composition == [['name1', _ellipsis, 'a1', aap(12)], ['name2'], [aap(14)]]
+ assert parsed.has_non_unitary_anonymous_axes
+ assert parsed.has_ellipsis
+ assert parsed.has_ellipsis_parenthesized
diff --git a/brainpy/_src/math/tests/test_others.py b/brainpy/_src/math/tests/test_others.py
new file mode 100644
index 000000000..084b8664d
--- /dev/null
+++ b/brainpy/_src/math/tests/test_others.py
@@ -0,0 +1,21 @@
+
+import brainpy.math as bm
+from scipy.special import exprel
+
+from unittest import TestCase
+
+
+class Test_exprel(TestCase):
+ def test1(self):
+ for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]:
+ print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}')
+ # self.assertEqual(exprel(x))
+
+ def test2(self):
+ bm.enable_x64()
+ for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]:
+ print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}')
+ # self.assertEqual(exprel(x))
+
+
+
diff --git a/brainpy/_src/math/tests/test_tifunc.py b/brainpy/_src/math/tests/test_tifunc.py
new file mode 100644
index 000000000..6823ebabd
--- /dev/null
+++ b/brainpy/_src/math/tests/test_tifunc.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+
+import jax
+import jax.numpy as jnp
+import pytest
+
+pytestmark = pytest.mark.skip(reason="Skipped due to a macOS limitation; manual execution is required for testing.")
+import brainpy.math as bm
+import taichi as ti
+import matplotlib.pyplot as plt
+import os
+
+
+bm.set_platform('cpu')
+
+
+def test_taichi_random():
+ @ti.kernel
+ def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
+ out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
+ key = bm.tifunc.lfsr88_key(seed[0])
+ for i in range(out.shape[0]):
+ key, result = bm.tifunc.lfsr88_rand(key)
+ out[i] = result
+
+ @ti.kernel
+ def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ for i in range(out.shape[0]):
+ out[i] = bm.tifunc.taichi_lcg_rand(seed)
+
+ @ti.kernel
+ def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1),
+ low_high: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ key = bm.tifunc.lfsr88_key(seed[0])
+ low = low_high[0]
+ high = low_high[1]
+ for i in range(out.shape[0]):
+ key, out[i] = bm.tifunc.lfsr88_randint(key, low, high)
+
+ @ti.kernel
+ def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1),
+ low_high: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ key = bm.tifunc.lfsr88_key(seed[0])
+ low = low_high[0]
+ high = low_high[1]
+ for i in range(out.shape[0]):
+ key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high)
+
+ @ti.kernel
+ def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1),
+ mu_sigma: ti.types.ndarray(ndim=1),
+ out: ti.types.ndarray(ndim=1)):
+ key = bm.tifunc.lfsr88_key(seed[0])
+ mu = mu_sigma[0]
+ sigma = mu_sigma[1]
+
+ for i in range(out.shape[0]):
+ key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma)
+
+ n = 100000
+ seed = jnp.array([1234, ], dtype=jnp.uint32)
+ low_high = jnp.array([0, 10])
+ mu_sigma = jnp.array([0, 1])
+
+ prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88,
+ gpu_kernel=test_taichi_lfsr88)
+
+
+ prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand,
+ gpu_kernel=test_taichi_lcg_rand)
+ prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution,
+ gpu_kernel=test_taichi_uniform_int_distribution)
+ prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution,
+ gpu_kernel=test_taichi_uniform_real_distribution)
+ prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution,
+ gpu_kernel=test_taichi_normal_distribution)
+
+ file_path = os.path.dirname(os.path.abspath(__file__))
+
+ out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])
+ # show the distribution of out
+ plt.hist(out, bins=100)
+ plt.title("LFSR88 random number generator")
+ plt.savefig(file_path + "/lfsr88.png")
+ plt.close()
+
+ out = prim_lcg_rand(seed,
+ outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])
+ # show the distribution of out
+ plt.hist(out, bins=100)
+ plt.title("LCG random number generator")
+ plt.savefig(file_path + "/lcg_rand.png")
+ plt.close()
+
+ out = prim_uniform_int_distribution(seed, low_high,
+ outs=[jax.ShapeDtypeStruct((n,), jnp.int32)])
+ # show the distribution of out
+ plt.hist(out, bins=10)
+ plt.title("Uniform int distribution (0, 10)")
+ plt.savefig(file_path + "/uniform_int_distribution.png")
+ plt.close()
+
+ out = prim_uniform_real_distribution(seed, low_high,
+ outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])
+ # show the distribution of out
+ plt.hist(out, bins=100)
+ plt.title("Uniform real distribution (0, 10)")
+ plt.savefig(file_path + "/uniform_real_distribution.png")
+ plt.close()
+
+ out = prim_normal_distribution(seed, mu_sigma,
+ outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])
+ # show the distribution of out
+ plt.title("Normal distribution mu=0, sigma=1")
+ plt.hist(out, bins=100)
+ plt.savefig(file_path + "/normal_distribution.png")
+
+
+# TODO: test default types
diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py
new file mode 100644
index 000000000..a9ee39f4a
--- /dev/null
+++ b/brainpy/_src/math/tifunc.py
@@ -0,0 +1,364 @@
+from brainpy._src.dependency_check import import_taichi
+from . import defaults
+
+ti = import_taichi()
+
+__all__ = [
+ # taichi function for other utilities
+ 'warp_reduce_sum',
+
+ # taichi functions for random number generator with LFSR88 algorithm
+ 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn',
+ 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand',
+
+ # taichi functions for random number generator with LFSR113 algorithm
+ 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn',
+ 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand',
+]
+
+
+@ti.func
+def _lcg_rand(state: ti.types.ndarray(ndim=1)):
+ # LCG constants
+ state[0] = ti.u32(1664525) * state[0] + ti.u32(1013904223)
+ return state[0]
+
+
+@ti.func
+def taichi_lcg_rand(seed: ti.types.ndarray(ndim=1)):
+ """
+ Generate a random number using a linear congruential generator (LCG) implemented in Taichi.
+
+ Parameters:
+ seed (ti.types.ndarray): The seed value for the random number generator.
+
+ Returns:
+ float: A random number between 0 and 1.
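+
+ Example (a minimal usage sketch inside a Taichi kernel, mirroring the accompanying tests;
+ ``fill`` is just an illustrative kernel name):
+
+ >>> @ti.kernel
+ >>> def fill(seed: ti.types.ndarray(ndim=1), out: ti.types.ndarray(ndim=1)):
+ >>>     for i in range(out.shape[0]):
+ >>>         out[i] = taichi_lcg_rand(seed)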
+ """
+
+ return float(_lcg_rand(seed)) / ti.u32(2 ** 32 - 1)
+
+
+#############################################
+# Random Number Generator: LFSR88 algorithm #
+#############################################
+
+
+@ti.func
+def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32):
+ """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer).
+
+ This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``.
+
+ Source:
+ https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c
+
+ /**** VERY IMPORTANT **** :
+ The initial seeds s1, s2, s3 MUST be larger than
+ 1, 7, and 15 respectively.
+ */
+
+ Args:
+ seed: int. The seed value for the random number generator.
+
+ Returns:
+ ti.math.uvec4: The random key for the LFSR88 random number generator.
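+
+ Example (a minimal usage sketch following the pattern in the accompanying tests;
+ ``fill`` is just an illustrative kernel name):
+
+ >>> @ti.kernel
+ >>> def fill(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
+ >>>          out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
+ >>>     key = lfsr88_key(seed[0])
+ >>>     for i in range(out.shape[0]):
+ >>>         key, out[i] = lfsr88_rand(key)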
+ """
+ return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0))
+
+
+@ti.func
+def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32):
+ """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer).
+
+ Args:
+ key: The state value for the random number generator.
+
+ Returns:
+ ti.math.uvec4: The next random key.
+ """
+ b = ti.u32(((key[0] << 13) ^ key[0]) >> 19)
+ s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b
+ b = ((key[1] << 2) ^ key[1]) >> 25
+ s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b
+ b = ((key[2] << 3) ^ key[2]) >> 11
+ s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b
+ return ti.math.uvec4(s1, s2, s3, b)
+
+
+@ti.func
+def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10):
+ """
+ Generate a random number from the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm.
+
+ Args:
+ key: The state value for the random number generator.
+ mu: The mean of the normal distribution.
+ sigma: The standard deviation of the normal distribution.
+ epsilon: The epsilon value to avoid log(0).
+ """
+
+ key, r = lfsr88_randn(key, epsilon)
+ return key, mu + sigma * r
+
+
+@ti.func
+def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10):
+ """
+ Generate a random number with the standard normal distribution using the LFSR88 algorithm.
+
+ Args:
+ key: The state value for the random number generator.
+ epsilon: The epsilon value to avoid log(0).
+
+ References:
+ Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
+ Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method
+
+ """
+
+ key, u1 = lfsr88_rand(key)
+ key, u2 = lfsr88_rand(key)
+
+ # Ensure u1 is not zero to avoid log(0)
+ u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float)
+
+ # Normalize the uniform samples
+ mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float)
+
+ # Box-Muller transform
+ # z1 = mag * ti.cos(2 * ti.math.pi * u2)
+ z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float)
+
+ return key, z2
+
+
+@ti.func
+def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high):
+ """
+ Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm.
+
+ Parameters:
+ key: The state value used for random number generation.
+ low: The lower bound of the range.
+ high: The upper bound of the range.
+ """
+ key = lfsr88_next_key(key)
+ return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int)
+
+
+@ti.func
+def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32):
+ key = lfsr88_next_key(key)
+ return key, dtype(key[0] ^ key[1] ^ key[2])
+
+
+@ti.func
+def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high):
+ """
+ Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm.
+
+ Args:
+ key: The state value used for random number generation.
+ low: The lower bound of the range.
+ high: The upper bound of the range.
+ """
+ key = lfsr88_next_key(key)
+ r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float)
+ return key, ti.cast(r * (high - low) + low, defaults.ti_float)
+
+
+@ti.func
+def lfsr88_rand(key: ti.types.vector(4, ti.u32)):
+ """
+ Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm.
+
+ Args:
+ key: The state value used for random number generation.
+ """
+ key = lfsr88_next_key(key)
+ return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float)
+
+
+##############################################
+# Random Number Generator: LFSR113 algorithm #
+##############################################
+
+
+@ti.func
+def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32):
+ """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer).
+
+ This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``.
+
+ Source:
+ https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c
+
+ /**** VERY IMPORTANT **** :
+ The initial seeds s1, s2, s3, s4 MUST be larger than
+ 1, 7, 15, and 127 respectively.
+ */
+
+ Args:
+ seed: int. The seed value for the random number generator.
+
+ Returns:
+ ti.math.uvec4: The random key for the LFSR113 random number generator.
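+
+ Example (a sketch following the same key-threading pattern as the LFSR88 functions above;
+ ``fill`` is just an illustrative kernel name):
+
+ >>> @ti.kernel
+ >>> def fill(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
+ >>>          out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
+ >>>     key = lfsr113_key(seed[0])
+ >>>     for i in range(out.shape[0]):
+ >>>         key, out[i] = lfsr113_rand(key)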
+ """
+ return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127))
+
+
+@ti.func
+def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32):
+ """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer).
+
+ Args:
+ key: The state value for the random number generator.
+
+ Returns:
+ ti.math.uvec4: The next random key.
+ """
+ z1 = key[0]
+ z2 = key[1]
+ z3 = key[2]
+ z4 = key[3]
+ b = ((z1 << 6) ^ z1) >> 13
+ z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b)
+ b = ((z2 << 2) ^ z2) >> 27
+ z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b)
+ b = ((z3 << 13) ^ z3) >> 21
+ z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b)
+ b = ((z4 << 3) ^ z4) >> 12
+ z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b)
+ return ti.math.uvec4(z1, z2, z3, z4)
+
+
+@ti.func
+def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10):
+ """
+ Generate a random number from the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm.
+
+ Args:
+ key: The state value for the random number generator.
+ mu: The mean of the normal distribution.
+ sigma: The standard deviation of the normal distribution.
+ epsilon: The epsilon value to avoid log(0).
+ """
+
+ key, r = lfsr113_randn(key, epsilon)
+ return key, ti.cast(mu + sigma * r, defaults.ti_float)
+
+
+@ti.func
+def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10):
+ """
+ Generate a random number with the standard normal distribution using the LFSR113 algorithm.
+
+ Args:
+ key: The state value for the random number generator.
+ epsilon: The epsilon value to avoid log(0).
+
+ References:
+ Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
+ Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method
+
+ """
+
+ key, u1 = lfsr113_rand(key)
+ key, u2 = lfsr113_rand(key)
+
+ # Ensure u1 is not zero to avoid log(0)
+ u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float)
+
+ # Normalize the uniform samples
+ mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float)
+
+ # Box-Muller transform
+ # z1 = mag * ti.cos(2 * ti.math.pi * u2)
+ z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float)
+
+ return key, z2
+
+
+@ti.func
+def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high):
+ """
+ Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm.
+
+ Parameters:
+ key: The state value used for random number generation.
+ low: The lower bound of the range.
+ high: The upper bound of the range.
+ """
+ key = lfsr113_next_key(key)
+ return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int)
+
+
+@ti.func
+def lfsr113_randint(key: ti.types.vector(4, ti.u32)):
+ key = lfsr113_next_key(key)
+ return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int)
+
+
+@ti.func
+def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high):
+ """
+ Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm.
+
+ Args:
+ key: The state value used for random number generation.
+ low: The lower bound of the range.
+ high: The upper bound of the range.
+ """
+ key = lfsr113_next_key(key)
+ r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float)
+ return key, ti.cast(r * (high - low) + low, defaults.ti_float)
+
+
+@ti.func
+def lfsr113_rand(key: ti.types.vector(4, ti.u32)):
+ """
+ Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm.
+
+ Args:
+ key: The state value used for random number generation.
+ """
+ key = lfsr113_next_key(key)
+ return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float)
+
+
+###########################
+# Reductions: warp reduce #
+###########################
+
+
+@ti.func
+def warp_reduce_sum_all(val):
+ """
+ Warp reduce sum.
+
+ Args:
+ val (float): The value to be reduced.
+
+ Returns:
+ float: The reduced value.
+ """
+ for i in ti.static(range(1, 32)):
+ val += ti.static(ti.simt.warp.shfl_xor(val, i))
+ return val
+
+
+@ti.func
+def warp_reduce_sum(val):
+ """
+ Warp reduce sum.
+
+ Args:
+ val (float): The value to be reduced.
+
+ Returns:
+ float: The reduced value.
+ """
+ for offset in ti.static((16, 8, 4, 2, 1)):
+ val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset)
+ return val
diff --git a/brainpy/_src/measure/lfp.py b/brainpy/_src/measure/lfp.py
index 0662be8d9..434666efb 100644
--- a/brainpy/_src/measure/lfp.py
+++ b/brainpy/_src/measure/lfp.py
@@ -10,7 +10,7 @@
]
-def unitary_LFP(times, spikes, spike_type='exc',
+def unitary_LFP(times, spikes, spike_type,
xmax=0.2, ymax=0.2, va=200., lambda_=0.2,
sig_i=2.1, sig_e=2.1 * 1.5, location='soma layer', seed=None):
"""A kernel-based method to calculate unitary local field potentials (uLFP)
diff --git a/brainpy/_src/measure/tests/test_correlation.py b/brainpy/_src/measure/tests/test_correlation.py
index 950dbce1f..dd19ca8aa 100644
--- a/brainpy/_src/measure/tests/test_correlation.py
+++ b/brainpy/_src/measure/tests/test_correlation.py
@@ -1,110 +1,111 @@
-# -*- coding: utf-8 -*-
-
-
-import unittest
-from functools import partial
-
-from jax import jit
-
-import brainpy as bp
-import brainpy.math as bm
-
-
-class TestCrossCorrelation(unittest.TestCase):
- def test_c(self):
- bm.random.seed()
- spikes = bm.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T
- cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.)
- f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.))
- cc2 = f_cc(spikes)
- print(cc1, cc2)
- self.assertTrue(cc1 == cc2)
- bm.clear_buffer_memory()
-
- def test_cc(self):
- bm.random.seed()
- spikes = bm.ones((1000, 10))
- cc1 = bp.measure.cross_correlation(spikes, 1.)
- self.assertTrue(cc1 == 1.)
-
- spikes = bm.zeros((1000, 10))
- cc2 = bp.measure.cross_correlation(spikes, 1.)
- self.assertTrue(cc2 == 0.)
-
- bm.clear_buffer_memory()
-
- def test_cc2(self):
- bm.random.seed()
- spikes = bm.random.randint(0, 2, (1000, 10))
- print(bp.measure.cross_correlation(spikes, 1.))
- print(bp.measure.cross_correlation(spikes, 0.5))
- bm.clear_buffer_memory()
-
- def test_cc3(self):
- bm.random.seed()
- spikes = bm.random.random((1000, 100)) < 0.8
- print(bp.measure.cross_correlation(spikes, 1.))
- print(bp.measure.cross_correlation(spikes, 0.5))
- bm.clear_buffer_memory()
-
- def test_cc4(self):
- bm.random.seed()
- spikes = bm.random.random((1000, 100)) < 0.2
- print(bp.measure.cross_correlation(spikes, 1.))
- print(bp.measure.cross_correlation(spikes, 0.5))
- bm.clear_buffer_memory()
-
- def test_cc5(self):
- bm.random.seed()
- spikes = bm.random.random((1000, 100)) < 0.05
- print(bp.measure.cross_correlation(spikes, 1.))
- print(bp.measure.cross_correlation(spikes, 0.5))
- bm.clear_buffer_memory()
-
-
-class TestVoltageFluctuation(unittest.TestCase):
- def test_vf1(self):
- bm.random.seed()
- voltages = bm.random.normal(0, 10, size=(100, 10))
- print(bp.measure.voltage_fluctuation(voltages))
-
- bm.enable_x64()
- voltages = bm.ones((100, 10))
- r1 = bp.measure.voltage_fluctuation(voltages)
-
- jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False))
- jit_f = jit(lambda a: bp.measure.voltage_fluctuation(a, numpy=False))
- r2 = jit_f(voltages)
- print(r1, r2) # TODO: JIT results are different?
- # self.assertTrue(r1 == r2)
-
- bm.disable_x64()
- bm.clear_buffer_memory()
-
-
-class TestFunctionalConnectivity(unittest.TestCase):
- def test_cf1(self):
- bm.random.seed()
- act = bm.random.random((10000, 3))
- r1 = bp.measure.functional_connectivity(act)
-
- jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False))
- r2 = jit_f(act)
-
- self.assertTrue(bm.allclose(r1, r2))
- bm.clear_buffer_memory()
-
-
-class TestMatrixCorrelation(unittest.TestCase):
- def test_mc(self):
- bm.random.seed()
- A = bm.random.random((100, 100))
- B = bm.random.random((100, 100))
- r1 = (bp.measure.matrix_correlation(A, B))
-
- jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False))
- r2 = jit_f(A, B)
- self.assertTrue(bm.allclose(r1, r2))
- bm.clear_buffer_memory()
-
-
+# -*- coding: utf-8 -*-
+
+
+import unittest
+from functools import partial
+
+from jax import jit
+
+import brainpy as bp
+import brainpy.math as bm
+
+bm.set_platform('cpu')
+
+class TestCrossCorrelation(unittest.TestCase):
+ def test_c(self):
+ bm.random.seed()
+ spikes = bm.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T
+ cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.)
+ f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.))
+ cc2 = f_cc(spikes)
+ print(cc1, cc2)
+ self.assertTrue(cc1 == cc2)
+ bm.clear_buffer_memory()
+
+ def test_cc(self):
+ bm.random.seed()
+ spikes = bm.ones((1000, 10))
+ cc1 = bp.measure.cross_correlation(spikes, 1.)
+ self.assertTrue(cc1 == 1.)
+
+ spikes = bm.zeros((1000, 10))
+ cc2 = bp.measure.cross_correlation(spikes, 1.)
+ self.assertTrue(cc2 == 0.)
+
+ bm.clear_buffer_memory()
+
+ def test_cc2(self):
+ bm.random.seed()
+ spikes = bm.random.randint(0, 2, (1000, 10))
+ print(bp.measure.cross_correlation(spikes, 1.))
+ print(bp.measure.cross_correlation(spikes, 0.5))
+ bm.clear_buffer_memory()
+
+ def test_cc3(self):
+ bm.random.seed()
+ spikes = bm.random.random((1000, 100)) < 0.8
+ print(bp.measure.cross_correlation(spikes, 1.))
+ print(bp.measure.cross_correlation(spikes, 0.5))
+ bm.clear_buffer_memory()
+
+ def test_cc4(self):
+ bm.random.seed()
+ spikes = bm.random.random((1000, 100)) < 0.2
+ print(bp.measure.cross_correlation(spikes, 1.))
+ print(bp.measure.cross_correlation(spikes, 0.5))
+ bm.clear_buffer_memory()
+
+ def test_cc5(self):
+ bm.random.seed()
+ spikes = bm.random.random((1000, 100)) < 0.05
+ print(bp.measure.cross_correlation(spikes, 1.))
+ print(bp.measure.cross_correlation(spikes, 0.5))
+ bm.clear_buffer_memory()
+
+
+class TestVoltageFluctuation(unittest.TestCase):
+ def test_vf1(self):
+ bm.random.seed()
+ voltages = bm.random.normal(0, 10, size=(100, 10))
+ print(bp.measure.voltage_fluctuation(voltages))
+
+ bm.enable_x64()
+ voltages = bm.ones((100, 10))
+ r1 = bp.measure.voltage_fluctuation(voltages)
+
+ jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False))
+ jit_f = jit(lambda a: bp.measure.voltage_fluctuation(a, numpy=False))
+ r2 = jit_f(voltages)
+ print(r1, r2) # TODO: JIT results are different?
+ # self.assertTrue(r1 == r2)
+
+ bm.disable_x64()
+ bm.clear_buffer_memory()
+
+
+class TestFunctionalConnectivity(unittest.TestCase):
+ def test_cf1(self):
+ bm.random.seed()
+ act = bm.random.random((10000, 3))
+ r1 = bp.measure.functional_connectivity(act)
+
+ jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False))
+ r2 = jit_f(act)
+
+ self.assertTrue(bm.allclose(r1, r2))
+ bm.clear_buffer_memory()
+
+
+class TestMatrixCorrelation(unittest.TestCase):
+ def test_mc(self):
+ bm.random.seed()
+ A = bm.random.random((100, 100))
+ B = bm.random.random((100, 100))
+ r1 = (bp.measure.matrix_correlation(A, B))
+
+ jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False))
+ r2 = jit_f(A, B)
+ self.assertTrue(bm.allclose(r1, r2))
+ bm.clear_buffer_memory()
+
+
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index 8ea8a5216..323fe872c 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -21,7 +21,6 @@
DynamicalSystem = None
delay_identifier, init_delay_by_return = None, None
-
__all__ = [
'MixIn',
'ParamDesc',
@@ -53,7 +52,6 @@ def _get_dynsys():
return DynamicalSystem
-
class MixIn(object):
"""Base MixIn object.
@@ -378,63 +376,131 @@ def get_delay_var(self, name):
class SupportInputProj(MixIn):
"""The :py:class:`~.MixIn` that receives the input projections.
- Note that the subclass should define a ``cur_inputs`` attribute.
+ Note that the subclass should define the ``current_inputs`` and ``delta_inputs``
+ attributes. Otherwise, the input function utilities cannot be used.
"""
- cur_inputs: bm.node_dict
+ current_inputs: bm.node_dict
+ delta_inputs: bm.node_dict
- def add_inp_fun(self, key: Any, fun: Callable):
+ def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'):
"""Add an input function.
Args:
- key: The dict key.
- fun: The function to generate inputs.
+ key: str. The dict key.
+ fun: Callable. The function to generate inputs.
+ label: str. The input label.
+ category: str. The input category: ``current`` for continuous current inputs, or
+ ``delta`` for delta (impulse-like) synaptic inputs.
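+
+ Example (a hedged sketch; ``proj``, the keys, and the label are illustrative names for an
+ object that mixes in ``SupportInputProj``):
+
+ >>> proj.add_inp_fun('noise', lambda: 0.1) # a current input
+ >>> proj.add_inp_fun('pulse', lambda: 1.0, label='E', category='delta') # a labeled delta input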
"""
if not callable(fun):
raise TypeError('Must be a function.')
- if key in self.cur_inputs:
- raise ValueError(f'Key "{key}" has been defined and used.')
- self.cur_inputs[key] = fun
- def get_inp_fun(self, key):
+ key = self._input_label_repr(key, label)
+ if category == 'current':
+ if key in self.current_inputs:
+ raise ValueError(f'Key "{key}" has been defined and used.')
+ self.current_inputs[key] = fun
+ elif category == 'delta':
+ if key in self.delta_inputs:
+ raise ValueError(f'Key "{key}" has been defined and used.')
+ self.delta_inputs[key] = fun
+ else:
+ raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".')
+
+ def get_inp_fun(self, key: str):
"""Get the input function.
Args:
- key: The key.
+ key: str. The key.
Returns:
The input function which generates currents.
"""
- return self.cur_inputs.get(key)
+ if key in self.current_inputs:
+ return self.current_inputs[key]
+ elif key in self.delta_inputs:
+ return self.delta_inputs[key]
+ else:
+ raise ValueError(f'Unknown key: {key}')
- def sum_inputs(self, *args, init=0., label=None, **kwargs):
- """Summarize all inputs by the defined input functions ``.cur_inputs``.
+ def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
+ """Summarize all current inputs by the defined input functions ``.current_inputs``.
Args:
*args: The arguments for input functions.
init: The initial input data.
+ label: str. The input label.
**kwargs: The arguments for input functions.
Returns:
The total currents.
"""
if label is None:
- for key, out in self.cur_inputs.items():
+ for key, out in self.current_inputs.items():
init = init + out(*args, **kwargs)
else:
- for key, out in self.cur_inputs.items():
- if key.startswith(label + ' // '):
+ label_repr = self._input_label_start(label)
+ for key, out in self.current_inputs.items():
+ if key.startswith(label_repr):
init = init + out(*args, **kwargs)
return init
+ def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
+ """Summarize all delta inputs by the defined input functions ``.delta_inputs``.
+
+ Args:
+ *args: The arguments for input functions.
+ init: The initial input data.
+ label: str. The input label.
+ **kwargs: The arguments for input functions.
+
+ Returns:
+ The total currents.
+ """
+ if label is None:
+ for key, out in self.delta_inputs.items():
+ init = init + out(*args, **kwargs)
+ else:
+ label_repr = self._input_label_start(label)
+ for key, out in self.delta_inputs.items():
+ if key.startswith(label_repr):
+ init = init + out(*args, **kwargs)
+ return init
-class SupportAutoDelay(MixIn):
+ @classmethod
+ def _input_label_start(cls, label: str):
+ # unify the input label repr.
+ return f'{label} // '
+
+ @classmethod
+ def _input_label_repr(cls, name: str, label: Optional[str] = None):
+ # unify the input label repr.
+ return name if label is None else (cls._input_label_start(label) + str(name))
+
+ # deprecated #
+ # ---------- #
+
+ @property
+ def cur_inputs(self):
+ return self.current_inputs
+
+ def sum_inputs(self, *args, **kwargs):
+ warnings.warn('Please use ".sum_current_inputs()" instead. ".sum_inputs()" will be removed.', UserWarning)
+ return self.sum_current_inputs(*args, **kwargs)
+
+
+class SupportReturnInfo(MixIn):
"""``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""
def return_info(self) -> Union[bm.Variable, ReturnInfo]:
raise NotImplementedError('Must implement the "return_info()" function.')
+class SupportAutoDelay(SupportReturnInfo):
+ pass
+
+
class SupportOnline(MixIn):
""":py:class:`~.MixIn` to support the online training methods.
@@ -519,19 +585,6 @@ def __subclasscheck__(self, subclass):
return all([issubclass(subclass, cls) for cls in self.__bases__])
-class UnionType2(MixIn):
- """Union type for multiple types.
-
- >>> import brainpy as bp
- >>>
- >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.SupportAutoDelay])
- """
-
- @classmethod
- def __class_getitem__(cls, types: Union[type, Sequence[type]]) -> type:
- return _MetaUnionType('UnionType', types, {})
-
-
if sys.version_info.minor > 8:
class _JointGenericAlias(_UnionGenericAlias, _root=True):
def __subclasscheck__(self, subclass):
diff --git a/brainpy/_src/optimizers/tests/test_ModifyLr.py b/brainpy/_src/optimizers/tests/test_ModifyLr.py
index 6e3cbf8c0..01e51016e 100644
--- a/brainpy/_src/optimizers/tests/test_ModifyLr.py
+++ b/brainpy/_src/optimizers/tests/test_ModifyLr.py
@@ -1,7 +1,8 @@
+from absl.testing import absltest
+from absl.testing import parameterized
+
import brainpy as bp
import brainpy.math as bm
-from absl.testing import parameterized
-from absl.testing import absltest
dt = 0.04
num_step = int(1.0 / dt)
@@ -33,15 +34,10 @@ def __init__(self, num_in, num_hidden):
def update(self, x):
return self.out(self.rnn(x))
-
-with bm.training_environment():
- model = RNN(1, 100)
-
-
-def loss(predictions, targets, l2_reg=2e-4):
- mse = bp.losses.mean_squared_error(predictions, targets)
- l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
- return mse + l2
+ def loss(self, predictions, targets, l2_reg=2e-4):
+ mse = bp.losses.mean_squared_error(predictions, targets)
+ l2 = l2_reg * bp.losses.l2_norm(self.train_vars().unique().dict()) ** 2
+ return mse + l2
class test_ModifyLr(parameterized.TestCase):
@@ -54,22 +50,28 @@ class test_ModifyLr(parameterized.TestCase):
]
)
def test_NewScheduler(self, LearningRate):
+ with bm.training_environment():
+ model = RNN(1, 100)
+
opt = bp.optim.Adam(lr=LearningRate, eps=1e-1)
- trainer = bp.BPTT(model, loss_fun=loss, optimizer=opt)
+ trainer = bp.BPTT(model, loss_fun=model.loss, optimizer=opt)
bm.clear_buffer_memory()
def test_modifylr(self):
+ with bm.training_environment():
+ model = RNN(1, 100)
+
Scheduler_lr = bp.optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt1 = bp.optim.Adam(lr=Scheduler_lr, eps=1e-1)
opt1.lr.lr = 0.01
- trainer1 = bp.BPTT(model, loss_fun=loss, optimizer=opt1)
+ trainer1 = bp.BPTT(model, loss_fun=model.loss, optimizer=opt1)
bm.clear_buffer_memory()
opt2 = bp.optim.SGD(lr=Scheduler_lr)
opt2.lr.set_value(0.01)
- trainer2 = bp.BPTT(model, loss_fun=loss, optimizer=opt2)
+ trainer2 = bp.BPTT(model, loss_fun=model.loss, optimizer=opt2)
bm.clear_buffer_memory()
diff --git a/brainpy/_src/running/pathos_multiprocessing.py b/brainpy/_src/running/pathos_multiprocessing.py
index f652217d9..e3eebe510 100644
--- a/brainpy/_src/running/pathos_multiprocessing.py
+++ b/brainpy/_src/running/pathos_multiprocessing.py
@@ -136,7 +136,7 @@ def cpu_ordered_parallel(
>>>
>>> def simulate(inp):
>>> inp = bm.as_jax(inp)
- >>> hh = bp.neurons.HH(1)
+ >>> hh = bp.dyn.HH(1)
>>> runner = bp.DSRunner(hh, inputs=['input', inp],
>>> monitors=['V', 'spike'],
>>> progress_bar=False)
@@ -194,7 +194,7 @@ def cpu_unordered_parallel(
>>>
>>> def simulate(inp):
>>> inp = bm.as_jax(inp)
- >>> hh = bp.neurons.HH(1)
+ >>> hh = bp.dyn.HH(1)
>>> runner = bp.DSRunner(hh, inputs=['input', inp],
>>> monitors=['V', 'spike'],
>>> progress_bar=False)
diff --git a/brainpy/_src/tests/test_dynsys.py b/brainpy/_src/tests/test_dynsys.py
index b7a2ebdab..f8605380e 100644
--- a/brainpy/_src/tests/test_dynsys.py
+++ b/brainpy/_src/tests/test_dynsys.py
@@ -1,3 +1,4 @@
+import unittest
import brainpy as bp
@@ -36,5 +37,19 @@ def update(self, tdi, x=None):
B()(1.)
+class TestResetLevelDecorator(unittest.TestCase):
+ _max_level = 10 # Define the maximum level for testing purposes
+ @bp.reset_level(5)
+ def test_function_with_reset_level_5(self):
+ self.assertEqual(self.test_function_with_reset_level_5.reset_level, 5)
+ def test1(self):
+ with self.assertRaises(ValueError):
+ @bp.reset_level(12) # This should raise a ValueError
+ def test_function_with_invalid_reset_level(self):
+ pass # Call the function here to trigger the ValueError
+
+ @bp.reset_level(-3)
+ def test_function_with_negative_reset_level(self):
+ self.assertEqual(self.test_function_with_negative_reset_level.reset_level, self._max_level - 3)
diff --git a/brainpy/_src/tests/test_helper.py b/brainpy/_src/tests/test_helper.py
new file mode 100644
index 000000000..d8c85010b
--- /dev/null
+++ b/brainpy/_src/tests/test_helper.py
@@ -0,0 +1,30 @@
+import brainpy as bp
+
+import unittest
+
+
+class TestResetLevel(unittest.TestCase):
+
+ def test1(self):
+ class Level0(bp.DynamicalSystem):
+ @bp.reset_level(0)
+ def reset_state(self, *args, **kwargs):
+ print('Level 0')
+
+ class Level1(bp.DynamicalSystem):
+ @bp.reset_level(1)
+ def reset_state(self, *args, **kwargs):
+ print('Level 1')
+
+ class Net(bp.DynamicalSystem):
+ def __init__(self):
+ super().__init__()
+ self.l0 = Level0()
+ self.l1 = Level1()
+ self.l0_2 = Level0()
+ self.l1_2 = Level1()
+
+ net = Net()
+ net.reset()
+
+
diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py
new file mode 100644
index 000000000..cbc710dba
--- /dev/null
+++ b/brainpy/_src/tools/functions.py
@@ -0,0 +1,192 @@
+import inspect
+from functools import partial
+from operator import attrgetter
+from types import MethodType
+
+__all__ = [
+ 'compose', 'pipe'
+]
+
+
+def identity(x):
+ """ Identity function. Return x
+
+ >>> identity(3)
+ 3
+ """
+ return x
+
+
+def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None):
+ """ Like @property, but returns ``classval`` when used as a class attribute
+
+ >>> class MyClass(object):
+ ... '''The class docstring'''
+ ... @instanceproperty(classval=__doc__)
+ ... def __doc__(self):
+ ... return 'An object docstring'
+ ... @instanceproperty
+ ... def val(self):
+ ... return 42
+ ...
+ >>> MyClass.__doc__
+ 'The class docstring'
+ >>> MyClass.val is None
+ True
+ >>> obj = MyClass()
+ >>> obj.__doc__
+ 'An object docstring'
+ >>> obj.val
+ 42
+ """
+ if fget is None:
+ return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc,
+ classval=classval)
+ return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc,
+ classval=classval)
+
+
+class InstanceProperty(property):
+ """ Like @property, but returns ``classval`` when used as a class attribute
+
+ Should not be used directly. Use ``instanceproperty`` instead.
+ """
+
+ def __init__(self, fget=None, fset=None, fdel=None, doc=None,
+ classval=None):
+ self.classval = classval
+ property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc)
+
+ def __get__(self, obj, type=None):
+ if obj is None:
+ return self.classval
+ return property.__get__(self, obj, type)
+
+ def __reduce__(self):
+ state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval)
+ return InstanceProperty, state
+
+
+class Compose(object):
+ """ A composition of functions
+
+ See Also:
+ compose
+ """
+ __slots__ = 'first', 'funcs'
+
+ def __init__(self, funcs):
+ funcs = tuple(reversed(funcs))
+ self.first = funcs[0]
+ self.funcs = funcs[1:]
+
+ def __call__(self, *args, **kwargs):
+ ret = self.first(*args, **kwargs)
+ for f in self.funcs:
+ ret = f(ret)
+ return ret
+
+ def __getstate__(self):
+ return self.first, self.funcs
+
+ def __setstate__(self, state):
+ self.first, self.funcs = state
+
+ @instanceproperty(classval=__doc__)
+ def __doc__(self):
+ def composed_doc(*fs):
+ """Generate a docstring for the composition of fs.
+ """
+ if not fs:
+ # Argument name for the docstring.
+ return '*args, **kwargs'
+
+ return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:]))
+
+ try:
+ return (
+ 'lambda *args, **kwargs: ' +
+ composed_doc(*reversed((self.first,) + self.funcs))
+ )
+ except AttributeError:
+ # One of our callables does not have a `__name__`, whatever.
+ return 'A composition of functions'
+
+ @property
+ def __name__(self):
+ try:
+ return '_of_'.join(
+ (f.__name__ for f in reversed((self.first,) + self.funcs))
+ )
+ except AttributeError:
+ return type(self).__name__
+
+ def __repr__(self):
+ return '{.__class__.__name__}{!r}'.format(
+ self, tuple(reversed((self.first,) + self.funcs)))
+
+ def __eq__(self, other):
+ if isinstance(other, Compose):
+ return other.first == self.first and other.funcs == self.funcs
+ return NotImplemented
+
+ def __ne__(self, other):
+ equality = self.__eq__(other)
+ return NotImplemented if equality is NotImplemented else not equality
+
+ def __hash__(self):
+ return hash(self.first) ^ hash(self.funcs)
+
+ # Mimic the descriptor behavior of python functions.
+ # i.e. let Compose be called as a method when bound to a class.
+ # adapted from
+ # docs.python.org/3/howto/descriptor.html#functions-and-methods
+ def __get__(self, obj, objtype=None):
+ return self if obj is None else MethodType(self, obj)
+
+ # introspection with Signature is only possible from py3.3+
+ @instanceproperty
+ def __signature__(self):
+ base = inspect.signature(self.first)
+ last = inspect.signature(self.funcs[-1])
+ return base.replace(return_annotation=last.return_annotation)
+
+ __wrapped__ = instanceproperty(attrgetter('first'))
+
+
+def compose(*funcs):
+ """ Compose functions to operate in series.
+
+ Returns a function that applies other functions in sequence.
+
+ Functions are applied from right to left so that
+ ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``.
+
+ If no arguments are provided, the identity function (f(x) = x) is returned.
+
+ >>> inc = lambda i: i + 1
+ >>> compose(str, inc)(3)
+ '4'
+ """
+ if not funcs:
+ return identity
+ if len(funcs) == 1:
+ return funcs[0]
+ else:
+ return Compose(funcs)
+
+
+def pipe(*funcs):
+ """ Pipe a value through a sequence of functions
+
+ I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))``
+
+ We think of the value as progressing through a pipe of several
+ transformations, much like pipes in UNIX
+
+
+ >>> double = lambda i: 2 * i
+ >>> pipe(double, str)(3)
+ '6'
+ """
+ return compose(*reversed(funcs))
diff --git a/brainpy/_src/tools/install.py b/brainpy/_src/tools/install.py
index aadf0f5c0..68981a5ec 100644
--- a/brainpy/_src/tools/install.py
+++ b/brainpy/_src/tools/install.py
@@ -8,19 +8,11 @@
BrainPy needs jaxlib, please install it.
-1. If you are using Windows system, install jaxlib through
+1. If you are using BrainPy on a CPU platform, please install jaxlib through
- >>> pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html
+ >>> pip install jaxlib
-2. If you are using macOS platform, install jaxlib through
-
- >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
-
-3. If you are using Linux platform, install jaxlib through
-
- >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
-
-4. If you are using Linux + CUDA platform, install jaxlib through
+2. If you are using Linux + CUDA platform, install jaxlib through
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
@@ -29,10 +21,3 @@
For more detail installation instructions, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
'''
-
-
-brainpylib_install = '''
-
-'''
-
-
diff --git a/brainpy/_src/tools/package.py b/brainpy/_src/tools/package.py
index 0da2dd7ae..c459ecfac 100644
--- a/brainpy/_src/tools/package.py
+++ b/brainpy/_src/tools/package.py
@@ -9,7 +9,6 @@
__all__ = [
- 'import_numba',
'numba_jit',
'numba_seed',
'numba_range',
@@ -17,12 +16,6 @@
]
-def import_numba():
- if numba is None:
- raise ModuleNotFoundError('Numba is needed. Please install numba through:\n\n'
- '> pip install numba')
- return numba
-
SUPPORT_NUMBA = numba is not None
diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py
new file mode 100644
index 000000000..c285e561a
--- /dev/null
+++ b/brainpy/_src/tools/tests/test_functions.py
@@ -0,0 +1,24 @@
+
+import unittest
+
+import brainpy as bp
+import brainpy.math as bm
+
+
+class TestFunction(unittest.TestCase):
+ def test_compose(self):
+ f = lambda a: a + 1
+ g = lambda a: a * 10
+ fun1 = bp.tools.compose(f, g)
+ fun2 = bp.tools.pipe(g, f)
+
+ arr = bm.random.randn(10)
+ r1 = fun1(arr)
+ r2 = fun2(arr)
+ groundtruth = f(g(arr))
+ self.assertTrue(bm.allclose(r1, r2))
+ self.assertTrue(bm.allclose(r1, groundtruth))
+ bm.clear_buffer_memory()
+
+
+
diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py
index f395158c0..6809d7125 100644
--- a/brainpy/_src/train/back_propagation.py
+++ b/brainpy/_src/train/back_propagation.py
@@ -278,7 +278,7 @@ def fit(
for x, y in _training_data:
# reset state
if reset_state:
- self.target.reset_state(self._get_input_batch_size(x))
+ self.target.reset(self._get_input_batch_size(x))
self.reset_state()
# training
@@ -356,7 +356,7 @@ def fit(
for x, y in _testing_data:
# reset state
if reset_state:
- self.target.reset_state(self._get_input_batch_size(x))
+ self.target.reset(self._get_input_batch_size(x))
self.reset_state()
# testing
@@ -604,7 +604,7 @@ def predict(
# reset the model states
if reset_state:
- self.target.reset_state(self._get_input_batch_size(xs=inputs))
+ self.target.reset(self._get_input_batch_size(xs=inputs))
self.reset_state()
# init monitor
for key in self._monitors.keys():
diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py
index 212a22617..d80764f26 100644
--- a/brainpy/_src/train/online.py
+++ b/brainpy/_src/train/online.py
@@ -161,7 +161,7 @@ def fit(
# reset the model states
if reset_state:
num_batch = self._get_input_batch_size(xs)
- self.target.reset_state(num_batch)
+ self.target.reset(num_batch)
self.reset_state()
# format input/target data
diff --git a/brainpy/_src/transform.py b/brainpy/_src/transform.py
index c9a8e4b13..cc20c6686 100644
--- a/brainpy/_src/transform.py
+++ b/brainpy/_src/transform.py
@@ -275,7 +275,6 @@ def __call__(
return results
def reset_state(self, batch_size=None):
- self.target.reset_state(batch_size)
if self.i0 is not None:
self.i0.value = bm.as_jax(self._i0)
if self.t0 is not None:
diff --git a/brainpy/check.py b/brainpy/check.py
index a1c780106..fafc0551d 100644
--- a/brainpy/check.py
+++ b/brainpy/check.py
@@ -41,7 +41,7 @@
'is_all_objs',
'jit_error',
'jit_error_checking',
- 'jit_error2',
+ 'jit_error_checking_no_args',
'serialize_kwargs',
]
@@ -349,13 +349,13 @@ def is_float(
if not isinstance(value, (float, np.floating)):
raise ValueError(f'{name} must be a float, but got {type(value)}')
if min_bound is not None:
- jit_error2(value < min_bound,
- ValueError(f"{name} must be a float bigger than {min_bound}, "
+ jit_error_checking_no_args(value < min_bound,
+ ValueError(f"{name} must be a float bigger than {min_bound}, "
f"while we got {value}"))
if max_bound is not None:
- jit_error2(value > max_bound,
- ValueError(f"{name} must be a float smaller than {max_bound}, "
+ jit_error_checking_no_args(value > max_bound,
+ ValueError(f"{name} must be a float smaller than {max_bound}, "
f"while we got {value}"))
return value
@@ -387,12 +387,12 @@ def is_integer(value: int, name=None, min_bound=None, max_bound=None, allow_none
else:
raise ValueError(f'{name} must be an int, but got {value}')
if min_bound is not None:
- jit_error2(jnp.any(value < min_bound),
- ValueError(f"{name} must be an int bigger than {min_bound}, "
+ jit_error_checking_no_args(jnp.any(value < min_bound),
+ ValueError(f"{name} must be an int bigger than {min_bound}, "
f"while we got {value}"))
if max_bound is not None:
- jit_error2(jnp.any(value > max_bound),
- ValueError(f"{name} must be an int smaller than {max_bound}, "
+ jit_error_checking_no_args(jnp.any(value > max_bound),
+ ValueError(f"{name} must be an int smaller than {max_bound}, "
f"while we got {value}"))
return value
@@ -596,7 +596,7 @@ def jit_error(pred, err_fun, err_arg=None):
Parameters
----------
- pred: bool
+ pred: bool, Array
The boolean prediction.
err_fun: callable
The error function, which raise errors.
@@ -610,7 +610,7 @@ def jit_error(pred, err_fun, err_arg=None):
jit_error_checking = jit_error
-def jit_error2(pred: bool, err: Exception):
+def jit_error_checking_no_args(pred: bool, err: Exception):
"""Check errors in a jit function.
Parameters
diff --git a/brainpy/dnn/others.py b/brainpy/dnn/others.py
index 7bd47b928..717dff569 100644
--- a/brainpy/dnn/others.py
+++ b/brainpy/dnn/others.py
@@ -9,5 +9,6 @@
from brainpy._src.dnn.function import (
Activation,
Flatten,
+ Unflatten,
FunAsLayer,
)
diff --git a/brainpy/dyn/projections.py b/brainpy/dyn/projections.py
index b2f4c5304..23e1a7485 100644
--- a/brainpy/dyn/projections.py
+++ b/brainpy/dyn/projections.py
@@ -1,24 +1,24 @@
-
-from brainpy._src.dyn.projections.aligns import (
- VanillaProj,
- ProjAlignPostMg1,
- ProjAlignPostMg2,
- ProjAlignPost1,
- ProjAlignPost2,
- ProjAlignPreMg1,
- ProjAlignPreMg2,
- ProjAlignPre1,
- ProjAlignPre2,
+from brainpy._src.dyn.projections.vanilla import VanillaProj
+from brainpy._src.dyn.projections.delta import (
+ HalfProjDelta,
+ FullProjDelta,
+)
+from brainpy._src.dyn.projections.align_post import (
+ HalfProjAlignPostMg,
+ FullProjAlignPostMg,
+ HalfProjAlignPost,
+ FullProjAlignPost,
+)
+from brainpy._src.dyn.projections.align_pre import (
+ FullProjAlignPreSDMg,
+ FullProjAlignPreDSMg,
+ FullProjAlignPreSD,
+ FullProjAlignPreDS,
)
-
from brainpy._src.dyn.projections.conn import (
SynConn as SynConn,
)
-
-from brainpy._src.dyn.projections.others import (
- PoissonInput as PoissonInput,
-)
-
from brainpy._src.dyn.projections.inputs import (
InputVar,
+ PoissonInput,
)
diff --git a/brainpy/dyn/synapses.py b/brainpy/dyn/synapses.py
index 68be31944..9a097be1a 100644
--- a/brainpy/dyn/synapses.py
+++ b/brainpy/dyn/synapses.py
@@ -1,6 +1,5 @@
from brainpy._src.dyn.synapses.abstract_models import (
- Delta,
Expon,
Alpha,
DualExpon,
diff --git a/brainpy/initialize.py b/brainpy/initialize.py
index d2e946527..0c737bc0b 100644
--- a/brainpy/initialize.py
+++ b/brainpy/initialize.py
@@ -22,6 +22,7 @@
from brainpy._src.initialize.random_inits import (
Normal as Normal,
Uniform as Uniform,
+ TruncatedNormal as TruncatedNormal,
VarianceScaling as VarianceScaling,
KaimingUniform as KaimingUniform,
KaimingNormal as KaimingNormal,
diff --git a/brainpy/losses.py b/brainpy/losses.py
index bf5177b74..f2506742c 100644
--- a/brainpy/losses.py
+++ b/brainpy/losses.py
@@ -18,6 +18,7 @@
log_cosh_loss as log_cosh_loss,
ctc_loss_with_forward_probs as ctc_loss_with_forward_probs,
ctc_loss as ctc_loss,
+ multi_margin_loss as multi_margin_loss,
)
from brainpy._src.losses.comparison import (
diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py
index e24d30ae0..02f671345 100644
--- a/brainpy/math/__init__.py
+++ b/brainpy/math/__init__.py
@@ -8,6 +8,7 @@
from .compat_numpy import *
from .compat_tensorflow import *
from .compat_pytorch import *
+from .einops import *
# functions
from .activations import *
@@ -32,33 +33,15 @@
from . import linalg
from . import random
+# taichi operations
+from . import tifunc
+
# others
from . import sharding
import jax.numpy as jnp
from jax import config
-mode = NonBatchingMode()
-'''Default computation mode.'''
-
-membrane_scaling = IdScaling()
-'''Default membrane_scaling.'''
-
-dt = 0.1
-'''Default time step.'''
-
-bool_ = jnp.bool_
-'''Default bool data type.'''
-
-int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32
-'''Default integer data type.'''
-
-float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32
-'''Default float data type.'''
-
-complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64
-'''Default complex data type.'''
-
del jnp, config
from brainpy._src.math.surrogate._compt import (
@@ -68,6 +51,7 @@
spike_with_mg_grad as spike_with_mg_grad,
)
+from brainpy._src.math import defaults
from brainpy._src.deprecations import deprecation_getattr
__deprecations = {
"sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.",
@@ -114,5 +98,6 @@
"Use brainpy.math.event.info instead.",
event.info),
}
-__getattr__ = deprecation_getattr(__name__, __deprecations)
-del deprecation_getattr
+
+__getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults)
+del deprecation_getattr, defaults
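The module-level defaults removed above are now served from `brainpy._src.math.defaults` through the `__getattr__` redirect, so existing attribute access keeps working; a minimal check (assuming `defaults.__all__` covers the same names):

```python
import brainpy.math as bm

# These attributes are resolved via the redirect instead of being defined
# directly in brainpy/math/__init__.py.
print(bm.dt)       # default time step, e.g. 0.1
print(bm.float_)   # default float dtype
print(bm.int_)     # default integer dtype
print(bm.mode)     # default computation mode
```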
diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py
index f522b6ab7..3b0c3f517 100644
--- a/brainpy/math/compat_pytorch.py
+++ b/brainpy/math/compat_pytorch.py
@@ -3,6 +3,7 @@
Tensor as Tensor,
flatten as flatten,
+ unflatten as unflatten,
cat as cat,
unsqueeze as unsqueeze,
abs as abs,
@@ -11,7 +12,7 @@
arccos as arccos,
acosh as acosh,
arccosh as arccosh,
- add as add,
+ # add as add,
addcdiv as addcdiv,
addcmul as addcmul,
angle as angle,
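A small sketch of the newly exported `unflatten`, assumed to mirror `torch.unflatten(input, dim, sizes)` like the other names in this compatibility module:

```python
import brainpy.math as bm

# Split axis 1 of size 6 into shape (2, 3); the (input, dim, sizes)
# signature is an assumption based on the PyTorch counterpart.
x = bm.random.rand(4, 6)
y = bm.unflatten(x, 1, (2, 3))
print(y.shape)   # (4, 2, 3) under that assumption
```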
diff --git a/brainpy/math/einops.py b/brainpy/math/einops.py
new file mode 100644
index 000000000..5dcb4ce67
--- /dev/null
+++ b/brainpy/math/einops.py
@@ -0,0 +1,6 @@
+from brainpy._src.math.einops import (
+ ein_repeat as ein_repeat,
+ ein_shape as ein_shape,
+ ein_reduce as ein_reduce,
+ ein_rearrange as ein_rearrange,
+)
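A brief usage sketch of the new `ein_*` helpers, assuming they follow einops-style pattern strings as their names suggest:

```python
import brainpy.math as bm

x = bm.random.rand(2, 3, 4)

# Pattern-string semantics assumed from einops; not spelled out in this hunk.
y = bm.ein_rearrange(x, 'b c t -> b (c t)')    # merge the last two axes
s = bm.ein_reduce(x, 'b c t -> b c', 'mean')   # average over the time axis
r = bm.ein_repeat(x, 'b c t -> b c t k', k=2)  # tile along a new axis
print(y.shape, s.shape, r.shape)               # (2, 12) (2, 3) (2, 3, 4, 2)
```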
diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py
index a283cc921..d654a0217 100644
--- a/brainpy/math/environment.py
+++ b/brainpy/math/environment.py
@@ -30,6 +30,7 @@
clear_buffer_memory as clear_buffer_memory,
enable_gpu_memory_preallocation as enable_gpu_memory_preallocation,
disable_gpu_memory_preallocation as disable_gpu_memory_preallocation,
+ gpu_memory_preallocation as gpu_memory_preallocation,
ditype as ditype,
dftype as dftype,
)
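A hedged sketch of the new helper, which presumably sets the fraction of GPU memory to preallocate (the existing helpers only toggle preallocation on or off; the argument's meaning is an assumption):

```python
import brainpy.math as bm

# Call before any computation so the setting can take effect; the 0-1
# fraction argument is assumed, not shown in this hunk.
bm.gpu_memory_preallocation(0.6)   # reserve roughly 60% of GPU memory
```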
diff --git a/brainpy/math/interoperability.py b/brainpy/math/interoperability.py
index 9bf4aee80..6956f9ba2 100644
--- a/brainpy/math/interoperability.py
+++ b/brainpy/math/interoperability.py
@@ -6,5 +6,7 @@
as_ndarray as as_ndarray,
as_numpy as as_numpy,
as_variable as as_variable,
+ from_numpy as from_numpy,
+ is_bp_array as is_bp_array,
)
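A minimal sketch of the two newly exported interoperability helpers, with usage inferred from their names:

```python
import numpy as np
import brainpy.math as bm

a = bm.from_numpy(np.arange(4))      # wrap a NumPy array as a BrainPy Array
print(bm.is_bp_array(a))             # True
print(bm.is_bp_array(np.arange(4)))  # False: plain NumPy array
```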
diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py
index 0b012f869..7654731d8 100644
--- a/brainpy/math/oo_transform.py
+++ b/brainpy/math/oo_transform.py
@@ -25,6 +25,7 @@
from brainpy._src.math.object_transform.autograd import (
grad as grad,
vector_grad as vector_grad,
+ functional_vector_grad as functional_vector_grad,
jacobian as jacobian,
jacrev as jacrev,
jacfwd as jacfwd,
@@ -39,6 +40,7 @@
ifelse as ifelse,
for_loop as for_loop,
while_loop as while_loop,
+ scan as scan,
)
@@ -57,3 +59,7 @@
eval_shape as eval_shape,
)
+from brainpy._src.math.object_transform.variables import (
+ VariableStack as VariableStack,
+)
+
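A minimal sketch of the newly exported `scan`, assuming it follows the `(carry, x) -> (carry, y)` convention of `jax.lax.scan` while tracking BrainPy variables like `for_loop`/`while_loop`:

```python
import brainpy.math as bm

def step(carry, x):
    carry = carry + x
    return carry, carry   # new carry, per-step output

# Running sum over the scanned axis; signature assumed to mirror
# jax.lax.scan(f, init, xs).
carry, cumsum = bm.scan(step, 0., bm.arange(5.))
print(cumsum)   # [0., 1., 3., 6., 10.]
```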
diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py
index b30ce4414..a48268ef4 100644
--- a/brainpy/math/op_register.py
+++ b/brainpy/math/op_register.py
@@ -4,10 +4,11 @@
from brainpy._src.math.op_register import (
CustomOpByNumba,
compile_cpu_signature_with_numba,
+ clean_caches,
+ check_kernels_count,
)
-
from brainpy._src.math.op_register.base import XLACustomOp
-
+from brainpy._src.math.op_register.ad_support import defjvp
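The names re-exported here form the public custom-operator surface; only the import is shown below, since the calling conventions are not part of this hunk and the role comments are inferred from the names:

```python
from brainpy.math.op_register import (
    CustomOpByNumba,       # Numba-backed custom operator
    XLACustomOp,           # general XLA custom-operator interface
    clean_caches,          # drop compiled-kernel caches
    check_kernels_count,   # inspect how many kernels are registered
    defjvp,                # attach a custom JVP rule to a registered op
)
```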
diff --git a/brainpy/math/others.py b/brainpy/math/others.py
index 23d9b0816..9b9d7b368 100644
--- a/brainpy/math/others.py
+++ b/brainpy/math/others.py
@@ -4,6 +4,7 @@
shared_args_over_time as shared_args_over_time,
remove_diag as remove_diag,
clip_by_norm as clip_by_norm,
+ exprel as exprel,
)
from brainpy._src.math.object_transform.naming import (
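A short sketch of the newly exported `exprel`, which computes `(exp(x) - 1) / x` in a numerically stable way (the limit at `x -> 0` is 1), a quantity that shows up in exponential-Euler updates:

```python
from brainpy.math.others import exprel

print(exprel(1e-12))   # ~1.0, without catastrophic cancellation
print(exprel(1.0))     # ~1.7182818 (e - 1)
```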
diff --git a/brainpy/math/random.py b/brainpy/math/random.py
index dde1f4832..922362d60 100644
--- a/brainpy/math/random.py
+++ b/brainpy/math/random.py
@@ -70,5 +70,4 @@
rand_like as rand_like,
randint_like as randint_like,
randn_like as randn_like,
-
)
diff --git a/brainpy/math/surrogate.py b/brainpy/math/surrogate.py
index 3f3daa2b7..0121bddec 100644
--- a/brainpy/math/surrogate.py
+++ b/brainpy/math/surrogate.py
@@ -1,11 +1,8 @@
# -*- coding: utf-8 -*-
-from brainpy._src.math.surrogate.base import (
- Surrogate
-)
-from brainpy._src.math.surrogate._one_input import (
+from brainpy._src.math.surrogate._one_input_new import (
Sigmoid,
sigmoid as sigmoid,
diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py
new file mode 100644
index 000000000..63f3cbe45
--- /dev/null
+++ b/brainpy/math/tifunc.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+
+from brainpy._src.math.tifunc import (
+ taichi_lcg_rand,
+
+ # warp reduction primitives
+ warp_reduce_sum,
+
+ # random number generator
+ lfsr88_key,
+ lfsr88_next_key,
+ lfsr88_normal,
+ lfsr88_randn,
+ lfsr88_random_integers,
+ lfsr88_randint,
+ lfsr88_uniform,
+ lfsr88_rand,
+ lfsr113_key,
+ lfsr113_next_key,
+ lfsr113_normal,
+ lfsr113_randn,
+ lfsr113_random_integers,
+ lfsr113_randint,
+ lfsr113_uniform,
+ lfsr113_rand
+)
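These taichi helpers are `ti.func`-style primitives meant to be called inside taichi kernels (per-thread LFSR88/LFSR113 random streams plus a warp-level reduction); only the import is shown, since the key-passing calling convention is an assumption based on the names and requires taichi to be installed:

```python
from brainpy.math.tifunc import (
    lfsr88_key, lfsr88_rand, lfsr88_randn, lfsr88_randint,
    lfsr113_key, lfsr113_rand,
    warp_reduce_sum,
)
```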
diff --git a/brainpy/tools.py b/brainpy/tools.py
index 0f3a4c0ef..233269dc5 100644
--- a/brainpy/tools.py
+++ b/brainpy/tools.py
@@ -45,4 +45,9 @@
)
+from brainpy._src.tools.functions import (
+ compose as compose,
+ pipe as pipe,
+)
+
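A hedged sketch of the new functional helpers, assuming toolz-style semantics (right-to-left for `compose`, left-to-right threading for `pipe`), which this hunk does not spell out:

```python
from brainpy.tools import compose, pipe

inc = lambda x: x + 1
dbl = lambda x: x * 2

print(compose(inc, dbl)(3))   # inc(dbl(3)) == 7 under that assumption
print(pipe(3, dbl, inc))      # inc(dbl(3)) == 7 under that assumption
```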
diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst
index 5c8cba0fd..0b78315ab 100644
--- a/docs/advanced_tutorials.rst
+++ b/docs/advanced_tutorials.rst
@@ -3,13 +3,52 @@ Advanced Tutorials
This section contains tutorials that illustrate more advanced features of BrainPy.
+Advanced Math
+-------------
.. toctree::
- :maxdepth: 2
+ :maxdepth: 1
+
+ tutorial_advanced/compilation.ipynb
+ tutorial_advanced/differentiation.ipynb
+
+
+Interoperation
+--------------
+
+.. toctree::
+ :maxdepth: 1
+
+ tutorial_advanced/integrate_flax_into_brainpy.ipynb
+ tutorial_advanced/integrate_bp_lif_into_flax.ipynb
+ tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb
+
+
+Brain Dynamics Dedicated Operators
+----------------------------------
+
+.. toctree::
+ :maxdepth: 1
+
+ tutorial_advanced/operator_custom_with_numba.ipynb
+ tutorial_advanced/operator_custom_with_taichi.ipynb
+
+
+Developer Guides
+----------------
+
+.. toctree::
+ :maxdepth: 1
+
+ tutorial_advanced/contributing.md
+
+
+Others
+------
+
+.. toctree::
+ :maxdepth: 1
+
+ tutorial_advanced/advanced_lowdim_analysis.ipynb
- tutorial_advanced/1_advanced_math.rst
- tutorial_advanced/2_interoperation.rst
- tutorial_advanced/3_dedicated_operators.rst
- tutorial_advanced/4_developer_guides.rst
- tutorial_advanced/5_others.rst
diff --git a/docs/apis/brainpy.dyn.projections.rst b/docs/apis/brainpy.dyn.projections.rst
index c1f8c1070..5549e6394 100644
--- a/docs/apis/brainpy.dyn.projections.rst
+++ b/docs/apis/brainpy.dyn.projections.rst
@@ -6,27 +6,23 @@ Synaptic Projections
-Reduced Projections
--------------------
+Projections for Align-Post Reduction
+------------------------------------
.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst
- ProjAlignPostMg1
- ProjAlignPostMg2
- ProjAlignPost1
- ProjAlignPost2
- ProjAlignPreMg1
- ProjAlignPreMg2
- ProjAlignPre1
- ProjAlignPre2
+ HalfProjAlignPostMg
+ FullProjAlignPostMg
+ HalfProjAlignPost
+ FullProjAlignPost
-Projections
------------
+Projections for Align-Pre Reduction
+------------------------------------
.. autosummary::
:toctree: generated/
@@ -34,7 +30,23 @@ Projections
:template: classtemplate.rst
VanillaProj
- SynConn
+ FullProjAlignPreSDMg
+ FullProjAlignPreDSMg
+ FullProjAlignPreSD
+ FullProjAlignPreDS
+
+
+
+Projections for Delta Synapses
+------------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ HalfProjDelta
+ FullProjDelta
@@ -46,6 +58,18 @@ Inputs
:nosignatures:
:template: classtemplate.rst
-
PoissonInput
InputVar
+
+
+
+Others
+------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ SynConn
+
diff --git a/docs/apis/brainpy.dyn.synapses.rst b/docs/apis/brainpy.dyn.synapses.rst
index ea4313c69..bea61ab87 100644
--- a/docs/apis/brainpy.dyn.synapses.rst
+++ b/docs/apis/brainpy.dyn.synapses.rst
@@ -42,7 +42,6 @@ Phenomenological synapse models
:nosignatures:
:template: classtemplate.rst
- Delta
Expon
Alpha
DualExpon
diff --git a/docs/apis/brainpy.math.defaults.rst b/docs/apis/brainpy.math.defaults.rst
new file mode 100644
index 000000000..515391dcf
--- /dev/null
+++ b/docs/apis/brainpy.math.defaults.rst
@@ -0,0 +1,22 @@
+
+Default Math Parameters
+=======================
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+
+ mode
+ membrane_scaling
+ dt
+ bool_
+ int_
+ ti_int
+ float_
+ ti_float
+ complex_
+
+
diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst
index 5ee94c615..9ed9cf46a 100644
--- a/docs/apis/brainpy.math.oo_transform.rst
+++ b/docs/apis/brainpy.math.oo_transform.rst
@@ -60,6 +60,7 @@ Object-oriented Transformations
ifelse
for_loop
while_loop
+ scan
jit
cls_jit
to_object
@@ -76,4 +77,5 @@ Helpers for Object-oriented Transformations
:template: classtemplate.rst
eval_shape
+ VariableStack
diff --git a/docs/apis/brainpy.math.op_register.rst b/docs/apis/brainpy.math.op_register.rst
index 7010b64eb..a50b4d300 100644
--- a/docs/apis/brainpy.math.op_register.rst
+++ b/docs/apis/brainpy.math.op_register.rst
@@ -6,6 +6,22 @@ Operator Registration
:depth: 1
+
+General Operator Customization Interface
+----------------------------------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ XLACustomOp
+
+
+
CPU Operator Customization with Numba
-------------------------------------
diff --git a/docs/apis/brainpy.math.random.rst b/docs/apis/brainpy.math.random.rst
index e52a3450b..5a0af2fa1 100644
--- a/docs/apis/brainpy.math.random.rst
+++ b/docs/apis/brainpy.math.random.rst
@@ -4,10 +4,15 @@
.. currentmodule:: brainpy.math.random
.. automodule:: brainpy.math.random
+
+
+Random Sampling Functions
+-------------------------
+
+
.. autosummary::
:toctree: generated/
:nosignatures:
- :template: classtemplate.rst
seed
split_key
@@ -70,6 +75,17 @@
rand_like
randint_like
randn_like
+
+
+Random Generator
+-------------------------
+
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
RandomState
Generator
DEFAULT
diff --git a/docs/apis/dnn.rst b/docs/apis/dnn.rst
index eea54ef24..c36a38186 100644
--- a/docs/apis/dnn.rst
+++ b/docs/apis/dnn.rst
@@ -17,8 +17,6 @@ Non-linear Activations
:template: classtemplate.rst
Activation
- Flatten
- FunAsLayer
Threshold
ReLU
RReLU
@@ -150,18 +148,16 @@ Interoperation with Flax
ToFlax
-Other Layers
-------------
+Utility Layers
+--------------
.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst
- Layer
Dropout
- Activation
Flatten
+ Unflatten
FunAsLayer
-
diff --git a/docs/apis/initialize.rst b/docs/apis/initialize.rst
index f516aa5b5..bd8c7031b 100644
--- a/docs/apis/initialize.rst
+++ b/docs/apis/initialize.rst
@@ -45,6 +45,7 @@ Random Initializers
Normal
Uniform
+ TruncatedNormal
VarianceScaling
KaimingUniform
KaimingNormal
diff --git a/docs/apis/losses.rst b/docs/apis/losses.rst
index 8f50c487f..4f4a3d167 100644
--- a/docs/apis/losses.rst
+++ b/docs/apis/losses.rst
@@ -33,6 +33,14 @@ Comparison
log_cosh_loss
ctc_loss_with_forward_probs
ctc_loss
+ multi_margin_loss
+
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
CrossEntropyLoss
NLLLoss
L1Loss
diff --git a/docs/apis/math.rst b/docs/apis/math.rst
index e3f0b765a..f4b778aba 100644
--- a/docs/apis/math.rst
+++ b/docs/apis/math.rst
@@ -24,6 +24,7 @@ dynamics programming. For more information and usage examples, please refer to t
:maxdepth: 1
brainpy.math.rst
+ brainpy.math.defaults.rst
brainpy.math.delayvars.rst
brainpy.math.oo_transform.rst
brainpy.math.pre_syn_post.rst
diff --git a/docs/core_concept/brainpy_dynamical_system.ipynb b/docs/core_concept/brainpy_dynamical_system.ipynb
index b8151486d..ba061243d 100644
--- a/docs/core_concept/brainpy_dynamical_system.ipynb
+++ b/docs/core_concept/brainpy_dynamical_system.ipynb
@@ -4,7 +4,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Concept 2: Dynamical System"
+ "# Concept 2: Dynamical System\n",
+ "\n",
+ "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs/core_concept/brainpy_dynamical_system.ipynb)\n",
+ "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs/core_concept/brainpy_dynamical_system.ipynb)"
]
},
{
@@ -425,7 +428,7 @@
" currents = bm.random.rand(200, 10, 100)\n",
"\n",
" # run the model\n",
- " net2.reset_state(batch_size=10)\n",
+ " net2.reset(10)\n",
" out = bm.for_loop(run_net2, (times, currents))\n",
"\n",
"out.shape"
@@ -459,7 +462,7 @@
}
],
"source": [
- "net2.reset_state(batch_size=10)\n",
+ "net2.reset(10)\n",
"looper = bp.LoopOverTime(net2)\n",
"out = looper(currents)\n",
"out.shape"
diff --git a/docs/core_concept/brainpy_transform_concept.ipynb b/docs/core_concept/brainpy_transform_concept.ipynb
index 5c2707567..3767a0aa4 100644
--- a/docs/core_concept/brainpy_transform_concept.ipynb
+++ b/docs/core_concept/brainpy_transform_concept.ipynb
@@ -9,7 +9,10 @@
}
},
"source": [
- "# Concept 1: Object-oriented Transformation"
+ "# Concept 1: Object-oriented Transformation\n",
+ "\n",
+ "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/docs/core_concept/brainpy_transform_concept.ipynb)\n",
+ "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs/core_concept/brainpy_transform_concept.ipynb)"
]
},
{
diff --git a/docs/index.rst b/docs/index.rst
index 1853bc97a..732b27aa2 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -2,90 +2,11 @@ BrainPy documentation
=====================
`BrainPy`_ is a highly flexible and extensible framework targeting on the
-general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, BrainPy supports:
+general-purpose Brain Dynamics Programming (BDP).
.. _BrainPy: https://github.com/brainpy/BrainPy
-Features
-^^^^^^^^^
-
-.. grid::
-
- .. grid-item::
- :columns: 12 12 12 6
-
- .. card:: OO Transformations
- :class-card: sd-border-0
- :shadow: none
- :class-title: sd-fs-5
-
- .. div:: sd-font-normal
-
- BrainPy supports object-oriented transformations, including
- JIT compilation, Autograd.
-
- .. grid-item::
- :columns: 12 12 12 6
-
- .. card:: Numerical Integrators
- :class-card: sd-border-0
- :shadow: none
- :class-title: sd-fs-5
-
- .. div:: sd-font-normal
-
- BrainPy provides various numerical integration methods for ODEs, SDEs, DDEs, FDEs, etc.
-
- .. grid-item::
- :columns: 12 12 12 6
-
- .. card:: Model Building
- :class-card: sd-border-0
- :shadow: none
- :class-title: sd-fs-5
-
- .. div:: sd-font-normal
-
- BrainPy provides a modular and composable programming interface for building dynamics.
-
- .. grid-item::
- :columns: 12 12 12 6
-
- .. card:: Model Simulation
- :class-card: sd-border-0
- :shadow: none
- :class-title: sd-fs-5
-
- .. div:: sd-font-normal
-
- BrainPy supports dynamics simulation for various brain objects with parallel supports.
-
-
- .. grid-item::
- :columns: 12 12 12 6
-
- .. card:: Model Training
- :class-card: sd-border-0
- :shadow: none
- :class-title: sd-fs-5
-
- .. div:: sd-font-normal
-
- BrainPy supports dynamics training with various machine learning algorithms, like FORCE learning, ridge regression, back-propagation, etc.
-
- .. grid-item::
- :columns: 12 12 12 6
-
- .. card:: Model Analysis
- :class-card: sd-border-0
- :shadow: none
- :class-title: sd-fs-5
-
- .. div:: sd-font-normal
-
- BrainPy supports dynamics analysis for low- and high-dimensional systems, including phase plane, bifurcation, linearization, and fixed/slow point analysis.
-
----
Installation
@@ -96,24 +17,18 @@ Installation
.. code-block:: bash
- pip install -U "jax[cpu]"
-
pip install -U brainpy brainpylib # windows, linux, macos
.. tab-item:: GPU (CUDA-11x)
.. code-block:: bash
- pip install -U "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-
pip install -U brainpy brainpylib-cu11x # only on linux
.. tab-item:: GPU (CUDA-12x)
.. code-block:: bash
- pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-
pip install -U brainpy brainpylib-cu12x # only on linux
For more information about supported accelerators and platforms, and for other installation details, please see `installation `_ section.
diff --git a/docs/quickstart/analysis.ipynb b/docs/quickstart/analysis.ipynb
index 02515a1aa..d8b62de11 100644
--- a/docs/quickstart/analysis.ipynb
+++ b/docs/quickstart/analysis.ipynb
@@ -5,7 +5,10 @@
"id": "ae1512d8",
"metadata": {},
"source": [
- "# Analyzing a Brain Dynamics Model"
+ "# Analyzing a Brain Dynamics Model\n",
+ "\n",
+ "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs/quickstart/analysis.ipynb)\n",
+ "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs/quickstart/analysis.ipynb)"
]
},
{
@@ -54,10 +57,19 @@
{
"cell_type": "code",
"execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-07-21T08:53:38.204162500Z",
+ "start_time": "2023-07-21T08:53:38.185849800Z"
+ },
+ "collapsed": false
+ },
"outputs": [
{
"data": {
- "text/plain": "'2.4.3'"
+ "text/plain": [
+ "'2.4.3'"
+ ]
},
"execution_count": 2,
"metadata": {},
@@ -66,14 +78,7 @@
],
"source": [
"bp.__version__"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2023-07-21T08:53:38.204162500Z",
- "start_time": "2023-07-21T08:53:38.185849800Z"
- }
- }
+ ]
},
{
"cell_type": "markdown",
@@ -93,6 +98,9 @@
},
{
"cell_type": "markdown",
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Let's try to analyze how the external input influences the dynamics of the Exponential Integrate-and-Fire (ExpIF) model. The ExpIF model is a one-variable neuron model whose dynamics is defined by:\n",
"\n",
@@ -100,10 +108,7 @@
"\\tau {\\dot {V}}= - (V - V_\\mathrm{rest}) + \\Delta_T \\exp(\\frac{V - V_T}{\\Delta_T}) + RI \\\\\n",
"\\mathrm{if}\\, \\, V > \\theta, \\quad V \\gets V_\\mathrm{reset}\n",
"$$"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "markdown",
@@ -149,7 +154,9 @@
"outputs": [
{
"data": {
- "text/plain": "(-65.0, -59.9, 1.0, 10.0)"
+ "text/plain": [
+ "(-65.0, -59.9, 1.0, 10.0)"
+ ]
},
"execution_count": 4,
"metadata": {},
@@ -188,8 +195,10 @@
},
{
"data": {
- "text/plain": "