diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 01bdd87f1..f59927666 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -51,7 +51,7 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
- pytest _src/
+ export IS_GITHUB_ACTIONS=1 && pytest _src/
test_macos:
@@ -79,10 +79,12 @@ jobs:
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
+ pip install jax==0.4.30
+ pip install jaxlib==0.4.30
- name: Test with pytest
run: |
cd brainpy
- pytest _src/
+ export IS_GITHUB_ACTIONS=1 && pytest _src/
test_windows:
@@ -113,4 +115,4 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
- pytest _src/ -p no:faulthandler
+ set IS_GITHUB_ACTIONS=1 && pytest _src/
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 1ab25592b..d277d9a4d 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -41,7 +41,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Docker Build & Push (version tag)
- uses: docker/build-push-action@v5
+ uses: docker/build-push-action@v6
with:
context: ${{ matrix.context }}
tags: ${{ matrix.base }}:${{ env.DOCKER_TAG_NAME }}
@@ -51,7 +51,7 @@ jobs:
- name: Docker Build & Push (latest tag)
if: |
(github.event_name == 'release' && ! github.event.release.prerelease)
- uses: docker/build-push-action@v5
+ uses: docker/build-push-action@v6
with:
context: ${{ matrix.context }}
tags: ${{ matrix.base }}:latest
diff --git a/README.md b/README.md
index a7fe0b721..5e2d8fa1f 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,21 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu
- **Source on OpenI**: https://git.openi.org.cn/OpenI/BrainPy
+---
+**NOTE**
+
+Building on our experimental BrainPy package, a more mature ecosystem for brain dynamics programming is emerging.
+Please see the [Brain Dynamics Programming Ecosystem](https://ecosystem-for-brain-dynamics.readthedocs.io/) for more details.
+
+
+If you use BrainPy heavily, please consider switching to [brainstate](https://brainstate.readthedocs.io) for a more stable, efficient, concise, and powerful experience.
+
+
+[brainstate](https://github.com/chaobrain/brainstate) is, and will continue to be, actively maintained and developed by our team. We highly recommend migrating your code to [brainstate](https://brainstate.readthedocs.io) for better performance.
+
+---
+
+
## Installation
@@ -50,8 +65,12 @@ We provide a Binder environment for BrainPy. You can use the following button to
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main)
+
+
## Ecosystem
+
+- **[Brain Dynamics Programming Ecosystem](https://ecosystem-for-brain-dynamics.readthedocs.io/)**: An emerging, increasingly mature ecosystem for brain dynamics programming.
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
@@ -67,21 +86,3 @@ Our team is committed to the long-term maintenance and development of the projec
If you are using ``brainpy``, please consider citing [the corresponding papers](https://brainpy.readthedocs.io/en/latest/tutorial_FAQs/citing_and_publication.html).
-
-
-## Ongoing development plans
-
-We highlight the key features and functionalities that are currently under active development.
-
-We also welcome your contributions
-(see [Contributing to BrainPy](https://brainpy.readthedocs.io/en/latest/tutorial_advanced/contributing.html)).
-
-- [x] model and data parallelization on multiple devices for dense connection models
-- [ ] model parallelization on multiple devices for sparse spiking network models
-- [ ] data parallelization on multiple devices for sparse spiking network models
-- [ ] pipeline parallelization on multiple devices for sparse spiking network models
-- [ ] multi-compartment modeling
-- [ ] measurements, analysis, and visualization methods for large-scale spiking data
-- [ ] Online learning methods for large-scale spiking network models
-- [ ] Classical plasticity rules for large-scale spiking network models
-
diff --git a/brainpy-changelog.md b/brainpy-changelog.md
index c949b7010..77e112b5a 100644
--- a/brainpy-changelog.md
+++ b/brainpy-changelog.md
@@ -2,6 +2,80 @@
## brainpy>2.3.x
+### Version 2.6.1
+#### Breaking Changes
+- Fixed compatibility issues between `numpy` and `jax`
+
+#### What's Changed
+* [doc] Add Chinese version of `operator_custom_with_cupy.ipynb` and rename its title by @Routhleck in https://github.com/brainpy/BrainPy/pull/659
+* Fix "amsgrad" is used before being defined when initializing the AdamW optimizer by @CloudyDory in https://github.com/brainpy/BrainPy/pull/660
+* fix issue #661 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/662
+* fix flax RNN interoperation, fix #663 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/665
+* [fix] Replace jax.experimental.host_callback with jax.pure_callback by @Routhleck in https://github.com/brainpy/BrainPy/pull/670
+* [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24 by @Routhleck in https://github.com/brainpy/BrainPy/pull/669
+* [math] Fix `CustomOpByNumba` on `multiple_results=True` by @Routhleck in https://github.com/brainpy/BrainPy/pull/671
+* [math] Implementing event-driven sparse matrix @ matrix operators by @Routhleck in https://github.com/brainpy/BrainPy/pull/613
+* [math] Add getting JIT connect matrix method for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/672
+* [math] Add get JIT weight matrix methods(Uniform & Normal) for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/673
+* support `Integrator.to_math_expr()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/674
+* [bug] Replace `collections.Iterable` with `collections.abc.Iterable` by @Routhleck in https://github.com/brainpy/BrainPy/pull/677
+* Fix surrogate gradient function and numpy 2.0 compatibility by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/679
+* :arrow_up: Bump docker/build-push-action from 5 to 6 by @dependabot in https://github.com/brainpy/BrainPy/pull/678
+* fix the incorrect verbose of `clear_name_cache()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/681
+* [bug] Fix the progress bar not being displayed and updated as expected by @Routhleck in https://github.com/brainpy/BrainPy/pull/683
+* Fix autograd by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/687
+
+
+**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.6.0...V2.6.1
+
+### Version 2.6.0
+
+#### New Features
+
+This release provides several new features, including:
+
+- ``MLIR`` registered operator customization interface in ``brainpy.math.XLACustomOp``.
+- Operator customization with CuPy JIT interface.
+- Bug fixes.
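+
+A minimal usage sketch of the ``brainpy.math.XLACustomOp`` interface mentioned above (the kernel, names, and shapes here are illustrative assumptions, not part of the release):
+
+```python
+import jax
+import jax.numpy as jnp
+import taichi as ti
+import brainpy.math as bm
+
+@ti.kernel
+def _double(x: ti.types.ndarray(ndim=1), out: ti.types.ndarray(ndim=1)):
+    # hypothetical kernel: write 2*x into the output buffer
+    for i in range(out.shape[0]):
+        out[i] = x[i] * 2.0
+
+# register the taichi kernel as a JAX/XLA custom operator (same kernel reused for CPU and GPU)
+double_op = bm.XLACustomOp(cpu_kernel=_double, gpu_kernel=_double)
+
+x = jnp.ones(4, dtype=jnp.float32)
+y = double_op(x, outs=[jax.ShapeDtypeStruct(x.shape, x.dtype)])[0]
+```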
+
+
+
+#### What's Changed
+* [doc] Fix the wrong path of more examples of `operator customized with taichi.ipynb` by @Routhleck in https://github.com/brainpy/BrainPy/pull/612
+* [docs] Add colab link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/614
+* Update requirements-doc.txt to fix doc building temporally by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/617
+* [math] Rebase operator customization using MLIR registration interface by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/618
+* [docs] Add kaggle link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/619
+* update requirements by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/620
+* require `brainpylib>=0.2.6` for `jax>=0.4.24` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/622
+* [tools] add `brainpy.tools.compose` and `brainpy.tools.pipe` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/624
+* doc hierarchy update by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/630
+* Standardizing and generalizing object-oriented transformations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/628
+* fix #626 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/631
+* Fix delayvar not correct in concat mode by @CloudyDory in https://github.com/brainpy/BrainPy/pull/632
+* [dependency] remove hard dependency of `taichi` and `numba` by @Routhleck in https://github.com/brainpy/BrainPy/pull/635
+* `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/639
+* add `brainpy.math.surrogate.Surrogate` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/638
+* Enable brainpy object as pytree so that it can be applied with ``jax.jit`` etc. directly by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/625
+* Fix ci by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/640
+* Clean taichi AOT caches by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/643
+* [ci] Fix windows pytest fatal exception by @Routhleck in https://github.com/brainpy/BrainPy/pull/644
+* [math] Support more than 8 parameters of taichi gpu custom operator definition by @Routhleck in https://github.com/brainpy/BrainPy/pull/642
+* Doc for ``brainpylib>=0.3.0`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/645
+* Find back updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/646
+* Update installation instruction by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/651
+* Fix delay bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/650
+* update doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/652
+* [math] Add new customize operators with `cupy` by @Routhleck in https://github.com/brainpy/BrainPy/pull/653
+* [math] Fix taichi custom operator on gpu backend by @Routhleck in https://github.com/brainpy/BrainPy/pull/655
+* update cupy operator custom doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/656
+* version 2.6.0 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/657
+* Upgrade CI by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/658
+
+#### New Contributors
+* @CloudyDory made their first contribution in https://github.com/brainpy/BrainPy/pull/632
+
+**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.5.0...V2.6.0
### Version 2.5.0
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 837efaf1d..e69837fda 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -153,3 +153,8 @@
del deprecation_getattr2
+# jax config
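+# '--xla_cpu_use_thunk_runtime=false' disables XLA's newer thunk-based CPU runtime,
+# and turning off async dispatch makes CPU computations run synchronously; both
+# settings are assumed here to be needed for BrainPy's callback-based custom
+# operators on recent jax/jaxlib releases.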
+import os
+os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
+import jax
+jax.config.update('jax_cpu_enable_async_dispatch', False)
diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py
index ee91b55a5..9f31946f2 100644
--- a/brainpy/_src/analysis/highdim/slow_points.py
+++ b/brainpy/_src/analysis/highdim/slow_points.py
@@ -329,7 +329,7 @@ def find_fps_with_gd_method(
"""
# optimization settings
if optimizer is None:
- optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
+ optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
if not isinstance(optimizer, optim.Optimizer):
diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py
index 1e1060625..60a394ce1 100644
--- a/brainpy/_src/dependency_check.py
+++ b/brainpy/_src/dependency_check.py
@@ -1,185 +1,66 @@
+import importlib.util
import os
import sys
-from jax.lib import xla_client
-
__all__ = [
- 'import_taichi',
- 'raise_taichi_not_found',
- 'import_numba',
- 'raise_numba_not_found',
- 'import_cupy',
- 'import_cupy_jit',
- 'raise_cupy_not_found',
- 'import_brainpylib_cpu_ops',
- 'import_brainpylib_gpu_ops',
+ 'import_taichi',
+ 'import_braintaichi',
+ 'raise_braintaichi_not_found',
]
-_minimal_brainpylib_version = '0.2.6'
-_minimal_taichi_version = (1, 7, 0)
-
-numba = None
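+# module handles are imported lazily by the functions below and cached here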
taichi = None
-cupy = None
-cupy_jit = None
-brainpylib_cpu_ops = None
-brainpylib_gpu_ops = None
+braintaichi = None
+braintaichi_install_info = ('We need braintaichi. Please install braintaichi via pip: \n'
+ '> pip install braintaichi -U')
-taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. '
- f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n'
- '> pip install taichi==1.7.0')
-numba_install_info = ('We need numba. Please install numba by pip . \n'
- '> pip install numba')
-cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
- 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
- 'For CUDA v12.x > pip install cupy-cuda12x\n')
os.environ["TI_LOG_LEVEL"] = "error"
def import_taichi(error_if_not_found=True):
- """Internal API to import taichi.
-
- If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
- otherwise it will return None.
- """
- global taichi
- if taichi is None:
- with open(os.devnull, 'w') as devnull:
- old_stdout = sys.stdout
- sys.stdout = devnull
- try:
- import taichi as taichi # noqa
- except ModuleNotFoundError:
- if error_if_not_found:
- raise raise_taichi_not_found()
- finally:
- sys.stdout = old_stdout
-
- if taichi is None:
- return None
- if taichi.__version__ != _minimal_taichi_version:
- raise RuntimeError(taichi_install_info)
- return taichi
-
-
-def raise_taichi_not_found(*args, **kwargs):
- raise ModuleNotFoundError(taichi_install_info)
-
-
-def import_numba(error_if_not_found=True):
- """
- Internal API to import numba.
-
- If numba is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
- otherwise it will return None.
- """
- global numba
- if numba is None:
- try:
- import numba as numba
- except ModuleNotFoundError:
- if error_if_not_found:
- raise_numba_not_found()
- else:
- return None
- return numba
-
-
-def raise_numba_not_found():
- raise ModuleNotFoundError(numba_install_info)
-
-
-def import_cupy(error_if_not_found=True):
- """
- Internal API to import cupy.
-
- If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
- otherwise it will return None.
- """
- global cupy
- if cupy is None:
- try:
- import cupy as cupy
- except ModuleNotFoundError:
- if error_if_not_found:
- raise_cupy_not_found()
- else:
- return None
- return cupy
-
-
-def import_cupy_jit(error_if_not_found=True):
- """
- Internal API to import cupy.
-
- If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
- otherwise it will return None.
- """
- global cupy_jit
- if cupy_jit is None:
- try:
- from cupyx import jit as cupy_jit
- except ModuleNotFoundError:
- if error_if_not_found:
- raise_cupy_not_found()
- else:
- return None
- return cupy_jit
-
-
-def raise_cupy_not_found():
- raise ModuleNotFoundError(cupy_install_info)
-
-
-def is_brainpylib_gpu_installed():
- return False if brainpylib_gpu_ops is None else True
-
-
-def import_brainpylib_cpu_ops():
- """
- Internal API to import brainpylib cpu_ops.
- """
- global brainpylib_cpu_ops
- if brainpylib_cpu_ops is None:
- try:
- from brainpylib import cpu_ops as brainpylib_cpu_ops
-
- for _name, _value in brainpylib_cpu_ops.registrations().items():
- xla_client.register_custom_call_target(_name, _value, platform="cpu")
-
- import brainpylib
- if brainpylib.__version__ < _minimal_brainpylib_version:
- raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
- if hasattr(brainpylib, 'check_brainpy_version'):
- brainpylib.check_brainpy_version()
-
- except ImportError:
- raise ImportError('Please install brainpylib. \n'
- 'See https://brainpy.readthedocs.io for installation instructions.')
-
- return brainpylib_cpu_ops
-
-
-def import_brainpylib_gpu_ops():
- """
- Internal API to import brainpylib gpu_ops.
- """
- global brainpylib_gpu_ops
- if brainpylib_gpu_ops is None:
- try:
- from brainpylib import gpu_ops as brainpylib_gpu_ops
-
- for _name, _value in brainpylib_gpu_ops.registrations().items():
- xla_client.register_custom_call_target(_name, _value, platform="gpu")
-
- import brainpylib
- if brainpylib.__version__ < _minimal_brainpylib_version:
- raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
- if hasattr(brainpylib, 'check_brainpy_version'):
- brainpylib.check_brainpy_version()
-
- except ImportError:
- raise ImportError('Please install GPU version of brainpylib. \n'
- 'See https://brainpy.readthedocs.io for installation instructions.')
-
- return brainpylib_gpu_ops
+ """Internal API to import taichi.
+
+ If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
+ otherwise it will return None.
+ """
+ global taichi
+ if taichi is None:
+ if importlib.util.find_spec('taichi') is not None:
+ with open(os.devnull, 'w') as devnull:
+ old_stdout = sys.stdout
+ sys.stdout = devnull
+ try:
+ import taichi as taichi # noqa
+ except ModuleNotFoundError as e:
+ if error_if_not_found:
+ raise e
+ finally:
+ sys.stdout = old_stdout
+ else:
+ taichi = None
+
+ return taichi
+
+
+def import_braintaichi(error_if_not_found=True):
+ """Internal API to import braintaichi.
+
+ If braintaichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
+ otherwise it will return None.
+ """
+ global braintaichi
+ if braintaichi is None:
+ if importlib.util.find_spec('braintaichi') is not None:
+ try:
+ import braintaichi as braintaichi
+ except ModuleNotFoundError:
+ if error_if_not_found:
+ raise_braintaichi_not_found()
+ else:
+ braintaichi = None
+ else:
+ braintaichi = None
+ return braintaichi
+
+
+def raise_braintaichi_not_found():
+ raise ModuleNotFoundError(braintaichi_install_info)
diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index c524fb0bf..06fa9413f 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
+import importlib.util
import numbers
from typing import Dict, Optional, Union, Callable
@@ -11,7 +12,7 @@
from brainpy import math as bm
from brainpy._src import connect, initialize as init
from brainpy._src.context import share
-from brainpy._src.dependency_check import import_taichi
+from brainpy._src.dependency_check import import_taichi, import_braintaichi
from brainpy._src.dnn.base import Layer
from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
from brainpy.check import is_initializer
@@ -20,1439 +21,1510 @@
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.types import ArrayType, Sharding
-ti = import_taichi(error_if_not_found=False)
+
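+# taichi and braintaichi are optional dependencies; the custom STDP primitives
+# defined below fall back to None when taichi is unavailable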
+ti = import_taichi()
+bti = import_braintaichi()
__all__ = [
- 'Dense', 'Linear',
- 'Identity',
- 'AllToAll',
- 'OneToOne',
- 'MaskedLinear',
- 'CSRLinear', 'EventCSRLinear',
- 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear',
- 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear',
+ 'Dense', 'Linear',
+ 'Identity',
+ 'AllToAll',
+ 'OneToOne',
+ 'MaskedLinear',
+ 'CSRLinear', 'EventCSRLinear',
+ 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear',
+ 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear',
]
class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline):
- r"""A linear transformation applied over the last dimension of the input.
-
- Mathematically, this node can be defined as:
-
- .. math::
-
- y = x \cdot weight + b
-
- Parameters
- ----------
- num_in: int
- The number of the input feature. A positive integer.
- num_out: int
- The number of the output features. A positive integer.
- W_initializer: optional, Initializer
- The weight initialization.
- b_initializer: optional, Initializer
- The bias initialization.
- mode: Mode
- Enable training this node or not. (default True)
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
- b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super(Dense, self).__init__(mode=mode, name=name)
-
- # shape
- self.num_in = num_in
- self.num_out = num_out
- if num_in < 0:
- raise ValueError(f'Received an invalid value for `num_out`, expected '
- f'a positive integer. Received: num_in={num_in}')
- if num_out < 0:
- raise ValueError(f'Received an invalid value for `num_out`, expected '
- f'a positive integer. Received: num_out={num_out}')
-
- # weight initializer
- self.W_initializer = W_initializer
- self.bias_initializer = b_initializer
- is_initializer(W_initializer, 'weight_initializer')
- is_initializer(b_initializer, 'bias_initializer', allow_none=True)
-
- # parameter initialization
- W = parameter(self.W_initializer, (num_in, self.num_out))
- b = parameter(self.bias_initializer, (self.num_out,))
- if isinstance(self.mode, bm.TrainingMode):
- W = bm.TrainVar(W)
- b = None if (b is None) else bm.TrainVar(b)
- self.W = W
- self.b = b
-
- # fitting parameters
- self.online_fit_by = None # support online training
- self.offline_fit_by = None # support offline training
- self.fit_record = dict()
-
- def __repr__(self):
- return (f'{self.__class__.__name__}(name={self.name}, '
- f'num_in={self.num_in}, '
- f'num_out={self.num_out}, '
- f'mode={self.mode})')
-
- def update(self, x):
- x = bm.as_jax(x)
- res = x @ self.W
- if self.b is not None:
- res += self.b
-
- # online fitting data
- if share.load('fit', False) and self.online_fit_by is not None:
- self.fit_record['input'] = x
- self.fit_record['output'] = res
-
- # offline fitting data
- if share.load('fit', False) and self.offline_fit_by is not None:
- self.fit_record['input'] = x
- self.fit_record['output'] = res
- return res
-
- def online_init(self):
- if self.b is None:
- num_input = self.num_in
- else:
- num_input = self.num_in + 1
- self.online_fit_by.register_target(feature_in=num_input, identifier=self.name)
-
- def online_fit(self,
- target: ArrayType,
- fit_record: Dict[str, ArrayType]):
- if not isinstance(target, (bm.ndarray, jnp.ndarray)):
- raise MathError(f'"target" must be a tensor, but got {type(target)}')
- x = fit_record['input']
- y = fit_record['output']
- if x.ndim != 2:
- raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, '
- f'num_feature), but we got {x.shape}')
- if target.ndim != 2:
- raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, '
- f'num_feature), but we got {target.shape}')
- if x.shape[0] != target.shape[0]:
- raise ValueError(f'Batch size of the input and target data should be '
- f'the same, while we got {x.shape[0]} != {target.shape[0]}.')
- if target.shape[1] != y.shape[1]:
- raise MathError(f'The output dimension of output and target data should be '
- f'the same, while we got {target.shape[1]} != {y.shape[1]}')
-
- # data
- if self.b is not None:
- x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1)
-
- # fitting
- dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name)
-
- # assign trained weights
- if self.b is None:
- self.W += dW
- else:
- db, dW = jnp.split(dW, [1])
- self.b += db[0]
- self.W += dW
-
- def offline_fit(self,
- target: ArrayType,
- fit_record: Dict[str, ArrayType]):
- """The offline training interface for the Dense node."""
- # data checking
- if not isinstance(target, (bm.ndarray, jnp.ndarray)):
- raise MathError(f'"targets" must be a tensor, but got {type(target)}')
- xs = fit_record['input']
- ys = fit_record['output']
- if xs.ndim != 3:
- raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, '
- f'num_feature), but we got {xs.shape}')
- if target.ndim != 3:
- raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, '
- f'num_feature), but we got {target.shape}')
- if ys.shape != target.shape:
- raise ValueError(f'The shapes of output and target data should be '
- f'the same, while we got {ys.shape} != {target.shape}.')
- if xs.shape[0] != target.shape[0]:
- raise ValueError(f'Batch size of the input and target data should be '
- f'the same, while we got {xs.shape[0]} != {target.shape[0]}.')
- if xs.shape[1] != target.shape[1]:
- raise MathError(f'The time dimension of input and target data should be '
- f'the same, while we got {xs.shape[1]} != {target.shape[1]}')
-
- # get input and target training data
- if self.b is not None:
- xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input)
-
- # solve weights by offline training methods
- weights = self.offline_fit_by(target, xs, ys)
-
- # assign trained weights
- if self.b is None:
- self.W.value = weights
- else:
- bias, Wff = jnp.split(weights, [1])
- self.W.value = Wff
- self.b.value = bias[0]
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if isinstance(self.W, float):
- raise ValueError(f'Cannot update the weight of a constant node.')
- if not isinstance(self.W, bm.Variable):
- self.tracing_variable('W', self.W, self.W.shape)
- if on_pre is not None:
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max)
- if on_post is not None:
- spike = on_post['spike']
- trace = on_post['trace']
- self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max)
+ r"""A linear transformation applied over the last dimension of the input.
+
+ Mathematically, this node can be defined as:
+
+ .. math::
+
+ y = x \cdot weight + b
+
+ Parameters
+ ----------
+ num_in: int
+ The number of the input feature. A positive integer.
+ num_out: int
+ The number of the output features. A positive integer.
+ W_initializer: optional, Initializer
+ The weight initialization.
+ b_initializer: optional, Initializer
+ The bias initialization.
+ mode: Mode
+ Enable training this node or not. (default True)
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
+ b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super(Dense, self).__init__(mode=mode, name=name)
+
+ # shape
+ self.num_in = num_in
+ self.num_out = num_out
+ if num_in < 0:
+ raise ValueError(f'Received an invalid value for `num_out`, expected '
+ f'a positive integer. Received: num_in={num_in}')
+ if num_out < 0:
+ raise ValueError(f'Received an invalid value for `num_out`, expected '
+ f'a positive integer. Received: num_out={num_out}')
+
+ # weight initializer
+ self.W_initializer = W_initializer
+ self.bias_initializer = b_initializer
+ is_initializer(W_initializer, 'weight_initializer')
+ is_initializer(b_initializer, 'bias_initializer', allow_none=True)
+
+ # parameter initialization
+ W = parameter(self.W_initializer, (num_in, self.num_out))
+ b = parameter(self.bias_initializer, (self.num_out,))
+ if isinstance(self.mode, bm.TrainingMode):
+ W = bm.TrainVar(W)
+ b = None if (b is None) else bm.TrainVar(b)
+ self.W = W
+ self.b = b
+
+ # fitting parameters
+ self.online_fit_by = None # support online training
+ self.offline_fit_by = None # support offline training
+ self.fit_record = dict()
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(name={self.name}, '
+ f'num_in={self.num_in}, '
+ f'num_out={self.num_out}, '
+ f'mode={self.mode})')
+
+ def update(self, x):
+ x = bm.as_jax(x)
+ res = x @ self.W
+ if self.b is not None:
+ res += self.b
+
+ # online fitting data
+ if share.load('fit', False) and self.online_fit_by is not None:
+ self.fit_record['input'] = x
+ self.fit_record['output'] = res
+
+ # offline fitting data
+ if share.load('fit', False) and self.offline_fit_by is not None:
+ self.fit_record['input'] = x
+ self.fit_record['output'] = res
+ return res
+
+ def online_init(self):
+ if self.b is None:
+ num_input = self.num_in
+ else:
+ num_input = self.num_in + 1
+ self.online_fit_by.register_target(feature_in=num_input, identifier=self.name)
+
+ def online_fit(self,
+ target: ArrayType,
+ fit_record: Dict[str, ArrayType]):
+ if not isinstance(target, (bm.ndarray, jnp.ndarray)):
+ raise MathError(f'"target" must be a tensor, but got {type(target)}')
+ x = fit_record['input']
+ y = fit_record['output']
+ if x.ndim != 2:
+ raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, '
+ f'num_feature), but we got {x.shape}')
+ if target.ndim != 2:
+ raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, '
+ f'num_feature), but we got {target.shape}')
+ if x.shape[0] != target.shape[0]:
+ raise ValueError(f'Batch size of the input and target data should be '
+ f'the same, while we got {x.shape[0]} != {target.shape[0]}.')
+ if target.shape[1] != y.shape[1]:
+ raise MathError(f'The output dimension of output and target data should be '
+ f'the same, while we got {target.shape[1]} != {y.shape[1]}')
+
+ # data
+ if self.b is not None:
+ x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1)
+
+ # fitting
+ dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name)
+
+ # assign trained weights
+ if self.b is None:
+ self.W += dW
+ else:
+ db, dW = jnp.split(dW, [1])
+ self.b += db[0]
+ self.W += dW
+
+ def offline_fit(self,
+ target: ArrayType,
+ fit_record: Dict[str, ArrayType]):
+ """The offline training interface for the Dense node."""
+ # data checking
+ if not isinstance(target, (bm.ndarray, jnp.ndarray)):
+ raise MathError(f'"targets" must be a tensor, but got {type(target)}')
+ xs = fit_record['input']
+ ys = fit_record['output']
+ if xs.ndim != 3:
+ raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, '
+ f'num_feature), but we got {xs.shape}')
+ if target.ndim != 3:
+ raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, '
+ f'num_feature), but we got {target.shape}')
+ if ys.shape != target.shape:
+ raise ValueError(f'The shapes of output and target data should be '
+ f'the same, while we got {ys.shape} != {target.shape}.')
+ if xs.shape[0] != target.shape[0]:
+ raise ValueError(f'Batch size of the input and target data should be '
+ f'the same, while we got {xs.shape[0]} != {target.shape[0]}.')
+ if xs.shape[1] != target.shape[1]:
+ raise MathError(f'The time dimension of input and target data should be '
+ f'the same, while we got {xs.shape[1]} != {target.shape[1]}')
+
+ # get input and target training data
+ if self.b is not None:
+ xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input)
+
+ # solve weights by offline training methods
+ weights = self.offline_fit_by(target, xs, ys)
+
+ # assign trained weights
+ if self.b is None:
+ self.W.value = weights
+ else:
+ bias, Wff = jnp.split(weights, [1])
+ self.W.value = Wff
+ self.b.value = bias[0]
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if isinstance(self.W, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(self.W, bm.Variable):
+ self.tracing_variable('W', self.W, self.W.shape)
+ if on_pre is not None:
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max)
+ if on_post is not None:
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max)
Linear = Dense
class Identity(Layer):
- r"""A placeholder identity operator that is argument-insensitive.
- """
+ r"""A placeholder identity operator that is argument-insensitive.
+ """
- def __init__(self, *args, **kwargs) -> None:
- super(Identity, self).__init__(*args, **kwargs)
+ def __init__(self, *args, **kwargs) -> None:
+ super(Identity, self).__init__(*args, **kwargs)
- def update(self, x):
- return x
+ def update(self, x):
+ return x
if ti is not None:
- # @numba.njit(nogil=True, fastmath=True, parallel=False)
- # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
- # out_w[:] = weight
- # for i in numba.prange(spike.shape[0]):
- # if spike[i]:
- # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max)
-
- @ti.kernel
- def _dense_on_post(
- old_w: ti.types.ndarray(ndim=2),
- post_spike: ti.types.ndarray(ndim=1),
- pre_trace: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- out_w: ti.types.ndarray(ndim=2)
- ):
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- num_pre, num_post = out_w.shape
-
- for i, j in ti.ndrange(num_pre, num_post):
- if post_spike[j]:
- new_value = out_w[i, j] + pre_trace[i]
- if new_value < w_min0:
- out_w[i, j] = w_min0
- elif new_value > w_max0:
- out_w[i, j] = w_max0
- else:
- out_w[i, j] = new_value
- else:
- out_w[i, j] = old_w[i, j]
-
-
- dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
-
-
- # @numba.njit(nogil=True, fastmath=True, parallel=False)
- # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w):
- # out_w[:] = weight
- # for i in numba.prange(spike.shape[0]):
- # if spike[i]:
- # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max)
-
- @ti.kernel
- def _dense_on_pre(
- old_w: ti.types.ndarray(ndim=2),
- pre_spike: ti.types.ndarray(ndim=1),
- post_trace: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- out_w: ti.types.ndarray(ndim=2)
- ):
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- num_pre, num_post = out_w.shape
-
- for i, j in ti.ndrange(num_pre, num_post):
- if pre_spike[i]:
- new_value = out_w[i, j] + post_trace[j]
- if new_value < w_min0:
- out_w[i, j] = w_min0
- elif new_value > w_max0:
- out_w[i, j] = w_max0
- else:
- out_w[i, j] = new_value
- else:
- out_w[i, j] = old_w[i, j]
-
-
- dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
+ # @numba.njit(nogil=True, fastmath=True, parallel=False)
+ # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
+ # out_w[:] = weight
+ # for i in numba.prange(spike.shape[0]):
+ # if spike[i]:
+ # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max)
+
+ @ti.kernel
+ def _dense_on_post(
+ old_w: ti.types.ndarray(ndim=2),
+ post_spike: ti.types.ndarray(ndim=1),
+ pre_trace: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ out_w: ti.types.ndarray(ndim=2)
+ ):
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ num_pre, num_post = out_w.shape
+
+ for i, j in ti.ndrange(num_pre, num_post):
+ if post_spike[j]:
+ new_value = out_w[i, j] + pre_trace[i]
+ if new_value < w_min0:
+ out_w[i, j] = w_min0
+ elif new_value > w_max0:
+ out_w[i, j] = w_max0
+ else:
+ out_w[i, j] = new_value
+ else:
+ out_w[i, j] = old_w[i, j]
+
+
+ dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
+
+
+ # @numba.njit(nogil=True, fastmath=True, parallel=False)
+ # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w):
+ # out_w[:] = weight
+ # for i in numba.prange(spike.shape[0]):
+ # if spike[i]:
+ # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max)
+
+ @ti.kernel
+ def _dense_on_pre(
+ old_w: ti.types.ndarray(ndim=2),
+ pre_spike: ti.types.ndarray(ndim=1),
+ post_trace: ti.types.ndarray(ndim=1),
+ w_min: ti.types.ndarray(ndim=1),
+ w_max: ti.types.ndarray(ndim=1),
+ out_w: ti.types.ndarray(ndim=2)
+ ):
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ num_pre, num_post = out_w.shape
+
+ for i, j in ti.ndrange(num_pre, num_post):
+ if pre_spike[i]:
+ new_value = out_w[i, j] + post_trace[j]
+ if new_value < w_min0:
+ out_w[i, j] = w_min0
+ elif new_value > w_max0:
+ out_w[i, j] = w_max0
+ else:
+ out_w[i, j] = new_value
+ else:
+ out_w[i, j] = old_w[i, j]
+
+
+ dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
else:
- dense_on_pre_prim = None
- dense_on_post_prim = None
+ dense_on_pre_prim = None
+ dense_on_post_prim = None
def dense_on_pre(weight, spike, trace, w_min, w_max):
- if dense_on_pre_prim is None:
- raise PackageMissingError.by_purpose('taichi', 'custom operators')
+ if dense_on_pre_prim is None:
+ raise PackageMissingError.by_purpose('taichi', 'custom operators')
+
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ w_min = jnp.atleast_1d(w_min)
+ w_max = jnp.atleast_1d(w_max)
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- w_min = jnp.atleast_1d(w_min)
- w_max = jnp.atleast_1d(w_max)
- return dense_on_pre_prim(weight, spike, trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
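+    # ensure plain jax arrays (not brainpy Arrays) before invoking the braintaichi custom operator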
+ weight = bm.as_jax(weight)
+ spike = bm.as_jax(spike)
+ trace = bm.as_jax(trace)
+ w_min = bm.as_jax(w_min)
+ w_max = bm.as_jax(w_max)
+ return dense_on_pre_prim(weight, spike, trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
def dense_on_post(weight, spike, trace, w_min, w_max):
- if dense_on_post_prim is None:
- raise PackageMissingError.by_purpose('taichi', 'custom operators')
+ if dense_on_post_prim is None:
+ raise PackageMissingError.by_purpose('taichi', 'custom operators')
+
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ w_min = jnp.atleast_1d(w_min)
+ w_max = jnp.atleast_1d(w_max)
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- w_min = jnp.atleast_1d(w_min)
- w_max = jnp.atleast_1d(w_max)
- return dense_on_post_prim(weight, spike, trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
+ weight = bm.as_jax(weight)
+ spike = bm.as_jax(spike)
+ trace = bm.as_jax(trace)
+ w_min = bm.as_jax(w_min)
+ w_max = bm.as_jax(w_max)
+ return dense_on_post_prim(weight, spike, trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
class AllToAll(Layer, SupportSTDP):
- """Synaptic matrix multiplication with All2All connections.
-
- Args:
- num_pre: int. The number of neurons in the presynaptic neuron group.
- num_post: int. The number of neurons in the postsynaptic neuron group.
- weight: The synaptic weights.
- sharding: The sharding strategy.
- include_self: bool. Whether connect the neuron with at the same position.
- mode: Mode. The computing mode.
- name: str. The object name.
- """
-
- def __init__(
- self,
- num_pre: int,
- num_post: int,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- include_self: bool = True,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(mode=mode, name=name)
-
- self.num_pre = num_pre
- self.num_post = num_post
- self.include_self = include_self
- self.sharding = sharding
-
- weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding)
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def update(self, pre_val):
- if bm.ndim(self.weight) == 0: # weight is a scalar
- if isinstance(self.mode, bm.BatchingMode):
- assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.'
- post_val = bm.sum(pre_val, keepdims=True, axis=1)
- else:
- assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.'
- post_val = bm.sum(pre_val)
- if not self.include_self:
- if self.num_pre == self.num_post:
- post_val = post_val - pre_val
- elif self.num_pre > self.num_post:
- val = pre_val[:self.num_post]
- post_val = post_val - val
- else:
- val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)])
- post_val = post_val - val
- post_val = self.weight * post_val
-
- else: # weight is a matrix
- assert self.weight.ndim == 2, '"weight" must be a 2D matrix.'
- if not self.include_self:
- post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False)
- else:
- post_val = pre_val @ self.weight
- return post_val
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if isinstance(self.weight, float):
- raise ValueError(f'Cannot update the weight of a constant node.')
- if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
- if on_pre is not None:
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
- if on_post is not None:
- spike = on_post['spike']
- trace = on_post['trace']
- self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
+ """Synaptic matrix multiplication with All2All connections.
+
+ Args:
+ num_pre: int. The number of neurons in the presynaptic neuron group.
+ num_post: int. The number of neurons in the postsynaptic neuron group.
+ weight: The synaptic weights.
+ sharding: The sharding strategy.
+ include_self: bool. Whether connect the neuron with at the same position.
+ mode: Mode. The computing mode.
+ name: str. The object name.
+ """
+
+ def __init__(
+ self,
+ num_pre: int,
+ num_post: int,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ include_self: bool = True,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(mode=mode, name=name)
+
+ self.num_pre = num_pre
+ self.num_post = num_post
+ self.include_self = include_self
+ self.sharding = sharding
+
+ weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding)
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def update(self, pre_val):
+ if bm.ndim(self.weight) == 0: # weight is a scalar
+ if isinstance(self.mode, bm.BatchingMode):
+ assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.'
+ post_val = bm.sum(pre_val, keepdims=True, axis=1)
+ else:
+ assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.'
+ post_val = bm.sum(pre_val)
+ if not self.include_self:
+ if self.num_pre == self.num_post:
+ post_val = post_val - pre_val
+ elif self.num_pre > self.num_post:
+ val = pre_val[:self.num_post]
+ post_val = post_val - val
+ else:
+ val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)])
+ post_val = post_val - val
+ post_val = self.weight * post_val
+
+ else: # weight is a matrix
+ assert self.weight.ndim == 2, '"weight" must be a 2D matrix.'
+ if not self.include_self:
+ post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False)
+ else:
+ post_val = pre_val @ self.weight
+ return post_val
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ if on_pre is not None:
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
+ if on_post is not None:
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
class OneToOne(Layer, SupportSTDP):
- """Synaptic matrix multiplication with One2One connection.
-
- Args:
- num: int. The number of neurons.
- weight: The synaptic weight.
- sharding: The sharding strategy.
- mode: The computing mode.
- name: The object name.
-
- """
-
- def __init__(
- self,
- num: int,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(mode=mode, name=name)
-
- self.num = num
- self.sharding = sharding
-
- weight = init.parameter(weight, (self.num,), sharding=sharding)
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def update(self, pre_val):
- return pre_val * self.weight
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if isinstance(self.weight, float):
- raise ValueError(f'Cannot update the weight of a constant node.')
- if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
- if on_pre is not None:
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.weight.value += spike * trace
- if on_post is not None:
- spike = on_post['spike']
- trace = on_post['trace']
- self.weight.value += spike * trace
+ """Synaptic matrix multiplication with One2One connection.
+
+ Args:
+ num: int. The number of neurons.
+ weight: The synaptic weight.
+ sharding: The sharding strategy.
+ mode: The computing mode.
+ name: The object name.
+
+ """
+
+ def __init__(
+ self,
+ num: int,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(mode=mode, name=name)
+
+ self.num = num
+ self.sharding = sharding
+
+ weight = init.parameter(weight, (self.num,), sharding=sharding)
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def update(self, pre_val):
+ return pre_val * self.weight
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ if on_pre is not None:
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.weight.value += spike * trace
+ if on_post is not None:
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.weight.value += spike * trace
class MaskedLinear(Layer, SupportSTDP):
- r"""Synaptic matrix multiplication with masked dense computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a dense matrix.
-
- >>> import brainpy as bp
- >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100),
- >>> weight=0.1)
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- mask_fun: Masking function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- mask_fun: Callable = Identity(),
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- assert isinstance(conn, connect.TwoEndConnector)
- self.conn = conn
- self.sharding = sharding
- self.mask_fun = mask_fun
-
- # weight
- weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding)
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- # connection
- self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding)
-
- def update(self, x):
- return x @ self.mask_fun(self.weight * self.mask)
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if isinstance(self.weight, float):
- raise ValueError(f'Cannot update the weight of a constant node.')
- if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
- if on_pre is not None:
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
- if on_post is not None:
- spike = on_post['spike']
- trace = on_post['trace']
- self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
+ r"""Synaptic matrix multiplication with masked dense computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a dense matrix.
+
+ >>> import brainpy as bp
+ >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100),
+ >>> weight=0.1)
+
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ mask_fun: Masking function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ mask_fun: Callable = Identity(),
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
+ self.sharding = sharding
+ self.mask_fun = mask_fun
+
+ # weight
+ weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding)
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ # connection
+ self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding)
+
+ def update(self, x):
+ return x @ self.mask_fun(self.weight * self.mask)
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ if on_pre is not None:
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
+ if on_post is not None:
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
class _CSRLayer(Layer, SupportSTDP):
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = True,
- ):
- super().__init__(name=name, mode=mode)
-
- assert isinstance(conn, connect.TwoEndConnector)
- assert sharding is None, 'Currently this model does not support sharding.'
- self.conn = conn
- self.sharding = sharding
- self.transpose = transpose
-
- # connection
- self.indices, self.indptr = self.conn.require('csr')
-
- # weight
- weight = init.parameter(weight, (self.indices.size,))
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def stdp_update(
- self,
- on_pre: Dict = None,
- on_post: Dict = None,
- w_min: numbers.Number = None,
- w_max: numbers.Number = None
- ):
- if bm.isscalar(self.weight):
- raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.')
- if self.weight.shape != self.indices.shape:
- raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.')
- if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
- if on_pre is not None: # update on presynaptic spike
- spike = on_pre['spike']
- trace = on_pre['trace']
- self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max)
- if on_post is not None: # update on postsynaptic spike
- if not hasattr(self, '_pre_ids'):
- with jax.ensure_compile_time_eval():
- self._pre_ids, self._post_indptr, self.w_indices = csr2csc(
- [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size)
- )
- spike = on_post['spike']
- trace = on_post['trace']
- self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr,
- self.w_indices, spike, trace, w_min, w_max)
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = True,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ assert isinstance(conn, connect.TwoEndConnector)
+ assert sharding is None, 'Currently this model does not support sharding.'
+ self.conn = conn
+ self.sharding = sharding
+ self.transpose = transpose
+
+ # connection
+ self.indices, self.indptr = self.conn.require('csr')
+
+ # weight
+ weight = init.parameter(weight, (self.indices.size,))
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def stdp_update(
+ self,
+ on_pre: Dict = None,
+ on_post: Dict = None,
+ w_min: numbers.Number = None,
+ w_max: numbers.Number = None
+ ):
+ if bm.isscalar(self.weight):
+ raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.')
+ if self.weight.shape != self.indices.shape:
+ raise ValueError(
+ f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ if on_pre is not None: # update on presynaptic spike
+ spike = on_pre['spike']
+ trace = on_pre['trace']
+ self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min,
+ w_max)
+ if on_post is not None: # update on postsynaptic spike
+ if not hasattr(self, '_pre_ids'):
+ with jax.ensure_compile_time_eval():
+ self._pre_ids, self._post_indptr, self.w_indices = csr2csc(
+ [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size)
+ )
+ spike = on_post['spike']
+ trace = on_post['trace']
+ self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr,
+ self.w_indices, spike, trace, w_min, w_max)
class CSRLinear(_CSRLayer):
- r"""Synaptic matrix multiplication with CSR sparse computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a CSR sparse matrix.
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- method: str = None,
- transpose: bool = True,
- ):
- super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
- self.method = method
-
- def update(self, x):
- if x.ndim == 1:
- return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
- shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose)
- elif x.ndim > 1:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_csrmv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_csrmv(self, x):
- return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
- shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose)
+ r"""Synaptic matrix multiplication with CSR sparse computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a CSR sparse matrix.
+
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ method: str = None,
+ transpose: bool = True,
+ ):
+ super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
+ self.method = method
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
+ shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose)
+ elif x.ndim > 1:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_csrmv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'The input must be at least 1D, but got {x.ndim}D.')
+
+ def _batch_csrmv(self, x):
+ return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
+ shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose)
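
As exercised in the updated `test_linear.py` later in this patch, the layer accepts both single vectors and batches. A usage sketch; the `FixedProb` connector arguments are illustrative assumptions, and the sparse kernels may require the optional taichi dependency:

```python
import jax.numpy as jnp
import brainpy as bp
import brainpy.math as bm

# Illustrative sketch mirroring test_CSRLinear; the connector choice is an assumption.
conn = bp.conn.FixedProb(0.1, pre=100, post=100)
f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())

y1 = f(jnp.asarray(bm.random.random((100,))))      # single presynaptic vector
y2 = f(jnp.asarray(bm.random.random((16, 100))))   # a batch, handled via jax.vmap
assert y1.shape == (100,) and y2.shape == (16, 100)
```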
class EventCSRLinear(_CSRLayer):
- r"""Synaptic matrix multiplication with event CSR sparse computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
- :math:`M` the synaptic weight using a CSR sparse matrix.
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = True,
- ):
- super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
-
- def update(self, x):
- if x.ndim == 1:
- return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
- shape=(self.conn.pre_num, self.conn.post_num),
- transpose=self.transpose)
- elif x.ndim > 1:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_csrmv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_csrmv(self, x):
- return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
- shape=(self.conn.pre_num, self.conn.post_num),
- transpose=self.transpose)
+ r"""Synaptic matrix multiplication with event CSR sparse computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+ :math:`M` the synaptic weight using a CSR sparse matrix.
+
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = True,
+ ):
+ super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
+ shape=(self.conn.pre_num, self.conn.post_num),
+ transpose=self.transpose)
+ elif x.ndim > 1:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_csrmv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'The input must be at least 1D, but got {x.ndim}D.')
+
+ def _batch_csrmv(self, x):
+ return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
+ shape=(self.conn.pre_num, self.conn.post_num),
+ transpose=self.transpose)
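
The event-driven variant has the same interface but treats the input as spike events. A sketch along the lines of `test_EventCSRLinear` (connector arguments again assumed):

```python
import jax.numpy as jnp
import brainpy as bp
import brainpy.math as bm

conn = bp.conn.FixedProb(0.1, pre=100, post=100)   # assumed connector, as above
f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal())

spikes = bm.random.random((100,)) < 0.2            # boolean spike vector
y = f(jnp.asarray(spikes))                         # event-driven sparse product
assert y.shape == (100,)
```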
if ti is not None:
- @ti.kernel
- def _csr_on_pre_update(
- old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1)
- spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
- trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
- w_min: ti.types.ndarray(ndim=1), # scalar
- w_max: ti.types.ndarray(ndim=1), # scalar
- out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
- ):
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- num_pre = spike.shape[0]
- for i_pre in range(num_pre):
- if spike[i_pre]:
- for i_syn in range(indptr[i_pre], indptr[i_pre + 1]):
- out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0)
- else:
- for i_syn in range(indptr[i_pre], indptr[i_pre + 1]):
- out_w[i_syn] = old_w[i_syn]
-
-
- csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
-
-
- @ti.kernel
- def _coo_on_pre_update(
- old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
- post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
- w_min: ti.types.ndarray(ndim=1), # scalar
- w_max: ti.types.ndarray(ndim=1), # scalar
- out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
- ):
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- num_syn = old_w.shape[0]
- for i_syn in range(num_syn):
- if pre_spike[pre_ids[i_syn]]: # pre spike
- out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0)
- else:
- out_w[i_syn] = old_w[i_syn]
-
-
- coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
-
-
- @ti.kernel
- def _coo_on_post_update(
- old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
- pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
- w_min: ti.types.ndarray(ndim=1), # scalar
- w_max: ti.types.ndarray(ndim=1), # scalar
- out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
- ):
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- num_syn = old_w.shape[0]
- for i_syn in range(num_syn):
- if post_spike[post_ids[i_syn]]: # pre spike
- out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0)
- else:
- out_w[i_syn] = old_w[i_syn]
-
-
- coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
-
-
- # @numba.njit(nogil=True, fastmath=True, parallel=False)
- # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
- # out_w[:] = w
- # w_min = w_min[()]
- # w_max = w_max[()]
- # for i in numba.prange(spike.shape[0]): # post id
- # if spike[i]:
- # for k in range(indptr[i], indptr[i + 1]):
- # j = post_ids[k] # pre id
- # l = w_ids[k] # syn id
- # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max)
-
- @ti.kernel
- def _csc_on_post_update(
- old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_post + 1)
- w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
- pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
- w_min: ti.types.ndarray(ndim=1), # scalar
- w_max: ti.types.ndarray(ndim=1), # scalar
- out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
- ):
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- num_post = post_spike.shape[0]
- for i_post in range(num_post):
- if post_spike[i_post]:
- for k in range(indptr[i_post], indptr[i_post + 1]):
- i_syn = w_ids[k] # syn id
- out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0)
- else:
- for k in range(indptr[i_post], indptr[i_post + 1]):
- i_syn = w_ids[k] # syn id
- out_w[i_syn] = old_w[i_syn]
-
-
- csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
+ @ti.kernel
+ def _csr_on_pre_update(
+ old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1)
+ spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
+ trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
+ w_min: ti.types.ndarray(ndim=1), # scalar
+ w_max: ti.types.ndarray(ndim=1), # scalar
+ out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
+ ):
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ num_pre = spike.shape[0]
+ for i_pre in range(num_pre):
+ if spike[i_pre]:
+ for i_syn in range(indptr[i_pre], indptr[i_pre + 1]):
+ out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0)
+ else:
+ for i_syn in range(indptr[i_pre], indptr[i_pre + 1]):
+ out_w[i_syn] = old_w[i_syn]
+
+
+ csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
+
+
+ @ti.kernel
+ def _coo_on_pre_update(
+ old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
+ post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
+ w_min: ti.types.ndarray(ndim=1), # scalar
+ w_max: ti.types.ndarray(ndim=1), # scalar
+ out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
+ ):
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ num_syn = old_w.shape[0]
+ for i_syn in range(num_syn):
+ if pre_spike[pre_ids[i_syn]]: # pre spike
+ out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0)
+ else:
+ out_w[i_syn] = old_w[i_syn]
+
+
+ coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
+
+
+ @ti.kernel
+ def _coo_on_post_update(
+ old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+      post_spike: ti.types.ndarray(ndim=1),  # vector with shape of (num_post,)
+      pre_trace: ti.types.ndarray(ndim=1),  # vector with shape of (num_pre,)
+ w_min: ti.types.ndarray(ndim=1), # scalar
+ w_max: ti.types.ndarray(ndim=1), # scalar
+ out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
+ ):
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ num_syn = old_w.shape[0]
+ for i_syn in range(num_syn):
+      if post_spike[post_ids[i_syn]]:  # post spike
+ out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0)
+ else:
+ out_w[i_syn] = old_w[i_syn]
+
+
+ coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
+
+
+ # @numba.njit(nogil=True, fastmath=True, parallel=False)
+ # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
+ # out_w[:] = w
+ # w_min = w_min[()]
+ # w_max = w_max[()]
+ # for i in numba.prange(spike.shape[0]): # post id
+ # if spike[i]:
+ # for k in range(indptr[i], indptr[i + 1]):
+ # j = post_ids[k] # pre id
+ # l = w_ids[k] # syn id
+ # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max)
+
+ @ti.kernel
+ def _csc_on_post_update(
+ old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_post + 1)
+ w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
+ pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
+ w_min: ti.types.ndarray(ndim=1), # scalar
+ w_max: ti.types.ndarray(ndim=1), # scalar
+ out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
+ ):
+ w_min0 = w_min[0]
+ w_max0 = w_max[0]
+ num_post = post_spike.shape[0]
+ for i_post in range(num_post):
+ if post_spike[i_post]:
+ for k in range(indptr[i_post], indptr[i_post + 1]):
+ i_syn = w_ids[k] # syn id
+ out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0)
+ else:
+ for k in range(indptr[i_post], indptr[i_post + 1]):
+ i_syn = w_ids[k] # syn id
+ out_w[i_syn] = old_w[i_syn]
+
+
+ csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
else:
- csr_on_pre_update_prim = None
- coo_on_pre_update_prim = None
- csc_on_post_update_prim = None
+ csr_on_pre_update_prim = None
+ coo_on_pre_update_prim = None
+ csc_on_post_update_prim = None
def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
- if csr_on_pre_update_prim is None:
- raise PackageMissingError.by_purpose('taichi', 'customized operators')
-
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- w_min = jnp.atleast_1d(w_min)
- w_max = jnp.atleast_1d(w_max)
- return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+ if csr_on_pre_update_prim is None:
+ raise PackageMissingError.by_purpose('taichi', 'customized operators')
+
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ w_min = jnp.atleast_1d(w_min)
+ w_max = jnp.atleast_1d(w_max)
+
+ w = bm.as_jax(w)
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ spike = bm.as_jax(spike)
+ trace = bm.as_jax(trace)
+ w_min = bm.as_jax(w_min)
+ w_max = bm.as_jax(w_max)
+ return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
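
A usage sketch of the wrapper itself; it only runs when the optional taichi dependency is available (otherwise `PackageMissingError` is raised), and the import path is inferred from the file this patch edits:

```python
import jax.numpy as jnp
# Assumed import path for the module-level helper defined above.
from brainpy._src.dnn.linear import csr_on_pre_update

indptr = jnp.array([0, 2, 3, 5])
indices = jnp.array([0, 2, 1, 0, 3])
w = jnp.array([0.5, 0.1, 0.9, 0.4, 0.2])
spike = jnp.array([True, False, True])
trace = jnp.array([0.3, -0.2, 0.1, 0.05])

# Same toy data as the NumPy sketch above; updated weights are clamped to [0, 1].
new_w = csr_on_pre_update(w, indices, indptr, spike, trace, w_min=0., w_max=1.)
```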
def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None):
- if coo_on_pre_update_prim is None:
- raise PackageMissingError.by_purpose('taichi', 'customized operators')
+ if coo_on_pre_update_prim is None:
+ raise PackageMissingError.by_purpose('taichi', 'customized operators')
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- w_min = jnp.atleast_1d(w_min)
- w_max = jnp.atleast_1d(w_max)
- return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ w_min = jnp.atleast_1d(w_min)
+ w_max = jnp.atleast_1d(w_max)
+ w = bm.as_jax(w)
+ pre_ids = bm.as_jax(pre_ids)
+ post_ids = bm.as_jax(post_ids)
+ spike = bm.as_jax(spike)
+ trace = bm.as_jax(trace)
+ w_min = bm.as_jax(w_min)
+ w_max = bm.as_jax(w_max)
-def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None):
- if csc_on_post_update_prim is None:
- raise PackageMissingError.by_purpose('taichi', 'customized operators')
+ return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
- if w_min is None:
- w_min = -np.inf
- if w_max is None:
- w_max = np.inf
- w_min = jnp.atleast_1d(w_min)
- w_max = jnp.atleast_1d(w_max)
- return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max,
- outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+
+def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None):
+ if csc_on_post_update_prim is None:
+ raise PackageMissingError.by_purpose('taichi', 'customized operators')
+
+ if w_min is None:
+ w_min = -np.inf
+ if w_max is None:
+ w_max = np.inf
+ w_min = jnp.atleast_1d(w_min)
+ w_max = jnp.atleast_1d(w_max)
+
+ w = bm.as_jax(w)
+ post_ids = bm.as_jax(post_ids)
+ indptr = bm.as_jax(indptr)
+ w_ids = bm.as_jax(w_ids)
+ post_spike = bm.as_jax(post_spike)
+ pre_trace = bm.as_jax(pre_trace)
+ w_min = bm.as_jax(w_min)
+ w_max = bm.as_jax(w_max)
+ return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max,
+ outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
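
The post-spike path works on a CSC view of the same synapses; in `stdp_update` this view is built once with `csr2csc(..., data=np.arange(weight.size))`, so the returned `w_ids` map every CSC entry back to its slot in the CSR weight vector. A scipy-based sketch of that index bookkeeping (illustrative only; the in-tree `csr2csc` helper has its own implementation):

```python
import numpy as np
from scipy.sparse import csr_matrix

num_pre, num_post = 3, 4
indptr = np.array([0, 2, 3, 5])        # CSR row pointers (per presynaptic neuron)
indices = np.array([0, 2, 1, 0, 3])    # postsynaptic ids per synapse
syn_ids = np.arange(indices.size)      # carry the CSR slot of each synapse as data

csc = csr_matrix((syn_ids, indices, indptr), shape=(num_pre, num_post)).tocsc()
pre_ids, post_indptr, w_ids = csc.indices, csc.indptr, csc.data
# For post neuron j, w_ids[post_indptr[j]:post_indptr[j + 1]] are the CSR weight
# slots of its incoming synapses, and pre_ids[...] the matching presynaptic ids.
```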
class CSCLinear(Layer):
- r"""Synaptic matrix multiplication with CSC sparse computation.
+ r"""Synaptic matrix multiplication with CSC sparse computation.
- It performs the computation of:
+ It performs the computation of:
- .. math::
+ .. math::
- y = x @ M
+ y = x @ M
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a CSC sparse matrix.
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a CSC sparse matrix.
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
- assert isinstance(conn, connect.TwoEndConnector)
- self.conn = conn
- self.sharding = sharding
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
+ self.sharding = sharding
class BcsrMM(Layer):
- r"""Synaptic matrix multiplication with BCSR sparse computation.
+ r"""Synaptic matrix multiplication with BCSR sparse computation.
- It performs the computation of:
+ It performs the computation of:
- .. math::
+ .. math::
- y = x @ M
+ y = x @ M
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a BCSR sparse matrix.
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a BCSR sparse matrix.
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
- assert isinstance(conn, connect.TwoEndConnector)
- self.conn = conn
- self.sharding = sharding
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
+ self.sharding = sharding
class BcscMM(Layer):
- r"""Synaptic matrix multiplication with BCSC sparse computation.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
- :math:`M` the synaptic weight using a BCSC sparse matrix.
-
- Args:
- conn: TwoEndConnector. The connection.
- weight: Synaptic weights. Can be a scalar, array, or callable function.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- conn: connect.TwoEndConnector,
- weight: Union[float, ArrayType, Callable],
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- assert isinstance(conn, connect.TwoEndConnector)
- self.conn = conn
- self.sharding = sharding
-
-
-class JitFPHomoLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is the same :math:`weight`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- weight: float. The synaptic value at each position.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- weight: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = False,
- atomic: bool = False,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class JitFPUniformLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- w_low: float. The lowest value of the uniform distribution.
- w_high: float. The highest value of the uniform distribution.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- w_low: float,
- w_high: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = False,
- atomic: bool = False,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- self.w_low = w_low
- self.w_high = w_high
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class JitFPNormalLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- w_mu: float. The center of the normal distribution.
- w_sigma: float. The standard variance of the normal distribution.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- w_mu: float,
- w_sigma: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- transpose: bool = False,
- atomic: bool = False,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- self.w_mu = w_mu
- self.w_sigma = w_sigma
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class EventJitFPHomoLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is the same :math:`weight`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- weight: float. The synaptic value at each position.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- weight: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = False,
- atomic: bool = True,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 1000000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- self.weight = weight
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
+ r"""Synaptic matrix multiplication with BCSC sparse computation.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
+ :math:`M` the synaptic weight using a BCSC sparse matrix.
+
+ Args:
+ conn: TwoEndConnector. The connection.
+ weight: Synaptic weights. Can be a scalar, array, or callable function.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ weight: Union[float, ArrayType, Callable],
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
+ self.sharding = sharding
+
+
+class JitLinear(Layer):
+  def get_conn_matrix(self):
+    """Return the dense connectivity/weight matrix realized by this just-in-time layer."""
+    pass
+
+
+class JitFPHomoLayer(JitLinear):
+ def get_conn_matrix(self):
+ return bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class JitFPUniformLayer(JitLinear):
+ def get_conn_matrix(self):
+ return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class JitFPNormalLayer(JitLinear):
+ def get_conn_matrix(self):
+ return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
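
The new `JitLinear` hierarchy gives every just-in-time layer a `get_conn_matrix` method that materializes the otherwise implicit weight matrix; the updated tests check that the layer output matches the dense product `x @ W.T`. A sketch of that contract (the `prob` and `weight` values are illustrative):

```python
import brainpy as bp
import brainpy.math as bm

bm.random.seed()
f = bp.dnn.JitFPHomoLinear(100, 200, prob=0.1, weight=0.01, seed=123)

x = bm.random.random((16, 100))
y = f(x)

W = f.get_conn_matrix()           # dense (num_out, num_in) matrix, here (200, 100)
assert bm.allclose(y, x @ W.T)    # same result as the just-in-time product
```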
+class JitFPHomoLinear(JitFPHomoLayer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
+  :math:`M` the synaptic weight matrix, which has fixed sparse connectivity and weights.
+ Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+ and at each connection, the synaptic value is the same :math:`weight`.
+
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output feature. A positive integer.
+ prob: float. The connectivity probability.
+ weight: float. The synaptic value at each position.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+ atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ weight: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = False,
+ atomic: bool = False,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'The input must be at least 1D, but got {x.ndim}D.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
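
The `update` method above also accepts inputs with extra leading dimensions: everything before the last axis is flattened, pushed through `jax.vmap`, and reshaped back. A brief sketch with illustrative parameters:

```python
import brainpy as bp
import brainpy.math as bm

f = bp.dnn.JitFPHomoLinear(100, 200, prob=0.05, weight=0.02, seed=42)

x = bm.random.random((4, 8, 100))   # two leading batch-like dimensions
y = f(x)                            # flattened to (32, 100), vmapped, reshaped back
assert y.shape == (4, 8, 200)
```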
-class EventJitFPUniformLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- w_low: float. The lowest value of the uniform distribution.
- w_high: float. The highest value of the uniform distribution.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- w_low: float,
- w_high: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- transpose: bool = False,
- atomic: bool = True,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- self.w_low = w_low
- self.w_high = w_high
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
+class JitFPUniformLinear(JitFPUniformLayer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
+  :math:`M` the synaptic weight matrix, which has fixed sparse connectivity and weights.
+  Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+  and at each connection, the synaptic value is sampled from a uniform distribution :math:`U(w_{low}, w_{high})`.
+
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output feature. A positive integer.
+ prob: float. The connectivity probability.
+ w_low: float. The lowest value of the uniform distribution.
+ w_high: float. The highest value of the uniform distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+ atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ w_low: float,
+ w_high: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = False,
+ atomic: bool = False,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ self.w_low = w_low
+ self.w_high = w_high
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
-
-
-class EventJitFPNormalLinear(Layer):
- r"""Synaptic matrix multiplication with the just-in-time connectivity.
-
- It performs the computation of:
-
- .. math::
-
- y = x @ M
-
- where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
- :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
- Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
- and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`.
-
- Args:
- num_in: int. The number of the input feature. A positive integer.
- num_out: int. The number of the input feature. A positive integer.
- prob: float. The connectivity probability.
- w_mu: float. The center of the normal distribution.
- w_sigma: float. The standard variance of the normal distribution.
- seed: int. The random seed used to keep the reproducibility of the connectivity.
- transpose: bool. Transpose the JIT matrix or not. Default False.
- atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
- May be changed in the future.
- sharding: The sharding strategy.
- mode: The synaptic computing mode.
- name: The synapse model name.
- """
-
- def __init__(
- self,
- num_in: int,
- num_out: int,
- prob: float,
- w_mu: float,
- w_sigma: float,
- seed: Optional[int] = None,
- sharding: Optional[Sharding] = None,
- transpose: bool = False,
- atomic: bool = True,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- self.prob = prob
- self.sharding = sharding
- self.transpose = transpose
- self.seed = np.random.randint(0, 100000) if seed is None else seed
- self.atomic = atomic
- self.num_in = num_in
- self.num_out = num_out
-
- # weight
- self.w_mu = w_mu
- self.w_sigma = w_sigma
-
- def update(self, x):
- if x.ndim == 1:
- return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'The input must be at least 1D, but got {x.ndim}D.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class JitFPNormalLinear(JitFPNormalLayer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
+  :math:`M` the synaptic weight matrix, which has fixed sparse connectivity and weights.
+  Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+  and at each connection, the synaptic value is sampled from a normal distribution :math:`N(\mu, \sigma)`.
+
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output feature. A positive integer.
+ prob: float. The connectivity probability.
+ w_mu: float. The center of the normal distribution.
+    w_sigma: float. The standard deviation of the normal distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+ atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ w_mu: float,
+ w_sigma: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ transpose: bool = False,
+ atomic: bool = False,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ self.w_mu = w_mu
+ self.w_sigma = w_sigma
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
- elif x.ndim == 2:
- return jax.vmap(self._batch_mv)(x)
- elif x.ndim > 2:
- shapes = x.shape[:-1]
- x = bm.flatten(x, end_dim=-2)
- y = jax.vmap(self._batch_mv)(x)
- return bm.reshape(y, shapes + (y.shape[-1],))
- else:
- raise ValueError
-
- def _batch_mv(self, x):
- return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
- shape=(self.num_out, self.num_in),
- transpose=self.transpose,
- outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'The input must be at least 1D, but got {x.ndim}D.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class EventJitFPHomoLinear(JitFPHomoLayer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+  :math:`M` the synaptic weight matrix, which has fixed sparse connectivity and weights.
+ Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+ and at each connection, the synaptic value is the same :math:`weight`.
+
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output feature. A positive integer.
+ prob: float. The connectivity probability.
+ weight: float. The synaptic value at each position.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+    atomic: bool. Compute the post-synaptic value with the atomic summation. Default True.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ weight: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = False,
+ atomic: bool = True,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 1000000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ self.weight = weight
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'The input must be at least 1D, but got {x.ndim}D.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
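
Mirroring the updated `test_EventJitFPHomoLinear`, the event-driven layer consumes boolean spike vectors, and its realized connectivity can be checked against the dense product (illustrative `prob` and `weight` values):

```python
import brainpy as bp
import brainpy.math as bm

f = bp.dnn.EventJitFPHomoLinear(100, 200, prob=0.1, weight=0.5, seed=123)

spikes = bm.random.random((10, 100)) < 0.1   # boolean spike events
y = f(spikes)
assert y.shape == (10, 200)

W = f.get_conn_matrix()                      # dense (200, 100) weight matrix
assert bm.allclose(y, spikes @ W.T)          # matches the event-driven JIT product
```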
+class EventJitFPUniformLinear(JitFPUniformLayer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+  :math:`M` the synaptic weight matrix, which has fixed sparse connectivity and weights.
+  Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+  and at each connection, the synaptic value is sampled from a uniform distribution :math:`U(w_{low}, w_{high})`.
+
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output feature. A positive integer.
+ prob: float. The connectivity probability.
+ w_low: float. The lowest value of the uniform distribution.
+ w_high: float. The highest value of the uniform distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+    atomic: bool. Compute the post-synaptic value with the atomic summation. Default True.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ w_low: float,
+ w_high: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ transpose: bool = False,
+ atomic: bool = True,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ self.w_low = w_low
+ self.w_high = w_high
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'The input must be at least 1D, but got {x.ndim}D.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+
+
+class EventJitFPNormalLinear(JitFPNormalLayer):
+ r"""Synaptic matrix multiplication with the just-in-time connectivity.
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+  :math:`M` the synaptic weight matrix, which has fixed sparse connectivity and weights.
+  Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+  and at each connection, the synaptic value is sampled from a normal distribution :math:`N(\mu, \sigma)`.
+
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+    num_out: int. The number of the output feature. A positive integer.
+ prob: float. The connectivity probability.
+ w_mu: float. The center of the normal distribution.
+    w_sigma: float. The standard deviation of the normal distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+    atomic: bool. Compute the post-synaptic value with the atomic summation. Default True.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ prob: float,
+ w_mu: float,
+ w_sigma: float,
+ seed: Optional[int] = None,
+ sharding: Optional[Sharding] = None,
+ transpose: bool = False,
+ atomic: bool = True,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ self.prob = prob
+ self.sharding = sharding
+ self.transpose = transpose
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
+ self.atomic = atomic
+ self.num_in = num_in
+ self.num_out = num_out
+
+ # weight
+ self.w_mu = w_mu
+ self.w_sigma = w_sigma
+
+ def update(self, x):
+ if x.ndim == 1:
+ return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
+ elif x.ndim == 2:
+ return jax.vmap(self._batch_mv)(x)
+ elif x.ndim > 2:
+ shapes = x.shape[:-1]
+ x = bm.flatten(x, end_dim=-2)
+ y = jax.vmap(self._batch_mv)(x)
+ return bm.reshape(y, shapes + (y.shape[-1],))
+ else:
+      raise ValueError(f'The input must be at least 1D, but got {x.ndim}D.')
+
+ def _batch_mv(self, x):
+ return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
+ shape=(self.num_out, self.num_in),
+ transpose=self.transpose,
+ outdim_parallel=not self.atomic)
diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py
index 7a0fa57af..514706419 100644
--- a/brainpy/_src/dnn/tests/test_activation.py
+++ b/brainpy/_src/dnn/tests/test_activation.py
@@ -4,6 +4,7 @@
import brainpy.math as bm
+
class Test_Activation(parameterized.TestCase):
@parameterized.product(
diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py
index 05f523622..af38a355f 100644
--- a/brainpy/_src/dnn/tests/test_conv_layers.py
+++ b/brainpy/_src/dnn/tests/test_conv_layers.py
@@ -1,12 +1,17 @@
# -*- coding: utf-8 -*-
+import platform
import jax.numpy as jnp
+import pytest
from absl.testing import absltest
from absl.testing import parameterized
import brainpy as bp
import brainpy.math as bm
+if platform.system() == 'Darwin':
+ pytest.skip('skip Mac OS', allow_module_level=True)
+
class TestConv(parameterized.TestCase):
def test_Conv2D_img(self):
diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py
index 6cc445383..2fd7df2dd 100644
--- a/brainpy/_src/dnn/tests/test_linear.py
+++ b/brainpy/_src/dnn/tests/test_linear.py
@@ -1,14 +1,11 @@
import pytest
from absl.testing import absltest
from absl.testing import parameterized
+import jax.numpy as jnp
import brainpy as bp
import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
class TestLinear(parameterized.TestCase):
@@ -104,11 +101,11 @@ def test_CSRLinear(self, conn):
bm.random.seed()
f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
x = bm.random.random((16, 100))
- y = f(x)
+ y = f(jnp.asarray(x))
self.assertTrue(y.shape == (16, 100))
x = bm.random.random((100,))
- y = f(x)
+ y = f(jnp.asarray(x))
self.assertTrue(y.shape == (100,))
bm.clear_buffer_memory()
@@ -123,10 +120,10 @@ def test_EventCSRLinear(self, conn):
bm.random.seed()
f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal())
x = bm.random.random((16, 100))
- y = f(x)
+ y = f(jnp.asarray(x))
self.assertTrue(y.shape == (16, 100))
x = bm.random.random((100,))
- y = f(x)
+ y = f(jnp.asarray(x))
self.assertTrue(y.shape == (100,))
bm.clear_buffer_memory()
@@ -141,6 +138,11 @@ def test_JitFPHomoLinear(self, prob, weight, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))
+
+ conn_matrix = f.get_conn_matrix()
+ self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
+ # print(conn_matrix.shape)
+ # self.assertTrue(conn_matrix.shape == (200, 100))
bm.clear_buffer_memory()
@parameterized.product(
@@ -155,6 +157,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))
+
+ conn_matrix = f.get_conn_matrix()
+ self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()
@parameterized.product(
@@ -169,6 +174,9 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))
+
+ conn_matrix = f.get_conn_matrix()
+ self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()
@parameterized.product(
@@ -179,11 +187,15 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
def test_EventJitFPHomoLinear(self, prob, weight, shape):
bm.random.seed()
f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123)
- y = f(bm.random.random(shape + (100,)) < 0.1)
+ x = bm.random.random(shape + (100,)) < 0.1
+ y = f(x)
self.assertTrue(y.shape == shape + (200,))
y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))
+
+ conn_matrix = f.get_conn_matrix()
+ self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()
@parameterized.product(
@@ -195,11 +207,15 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape):
def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
bm.random.seed()
f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123)
- y = f(bm.random.random(shape + (100,)) < 0.1)
+ x = bm.random.random(shape + (100,)) < 0.1
+ y = f(x)
self.assertTrue(y.shape == shape + (200,))
y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))
+
+ conn_matrix = f.get_conn_matrix()
+ self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()
@parameterized.product(
@@ -211,11 +227,15 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
bm.random.seed()
f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123)
- y = f(bm.random.random(shape + (100,)) < 0.1)
+ x = bm.random.random(shape + (100,)) < 0.1
+ y = f(x)
self.assertTrue(y.shape == shape + (200,))
y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))
+
+ conn_matrix = f.get_conn_matrix()
+ self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()
diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py
index 10e9eeda2..eb87c201d 100644
--- a/brainpy/_src/dnn/tests/test_mode.py
+++ b/brainpy/_src/dnn/tests/test_mode.py
@@ -4,10 +4,6 @@
import brainpy as bp
import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
class Test_Conv(parameterized.TestCase):
diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py
index 18d9d9dc9..8fab94dec 100644
--- a/brainpy/_src/dyn/projections/tests/test_STDP.py
+++ b/brainpy/_src/dyn/projections/tests/test_STDP.py
@@ -6,10 +6,6 @@
import brainpy as bp
import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
bm.set_platform('cpu')
diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py
index eec2c9459..8bf2c150e 100644
--- a/brainpy/_src/dyn/projections/tests/test_aligns.py
+++ b/brainpy/_src/dyn/projections/tests/test_aligns.py
@@ -5,10 +5,6 @@
import brainpy as bp
import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
diff --git a/brainpy/_src/dyn/rates/tests/test_nvar.py b/brainpy/_src/dyn/rates/tests/test_nvar.py
index 38b578a6c..24659815c 100644
--- a/brainpy/_src/dyn/rates/tests/test_nvar.py
+++ b/brainpy/_src/dyn/rates/tests/test_nvar.py
@@ -11,7 +11,7 @@ class Test_NVAR(parameterized.TestCase):
def test_NVAR(self,mode):
bm.random.seed()
input=bm.random.randn(1,5)
- layer=bp.dnn.NVAR(num_in=5,
+ layer=bp.dyn.NVAR(num_in=5,
delay=10,
mode=mode)
if mode in [bm.NonBatchingMode()]:
diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py
index d068f2079..0b371bcbf 100644
--- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py
+++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py
@@ -7,10 +7,6 @@
import brainpy as bp
import brainpy.math as bm
from brainpy._src.dynold.synapses import abstract_models
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
class Test_Abstract_Synapse(parameterized.TestCase):
diff --git a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py
index 01a315261..b48cb5b71 100644
--- a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py
+++ b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py
@@ -6,10 +6,6 @@
import brainpy as bp
import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
biological_models = [
bp.synapses.AMPA,
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index a6fcc16a7..c20fc414b 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -806,7 +806,7 @@ def __init__(
else:
# should be a list/tuple/array of int
# do not check again
- if not isinstance(idx, collections.Iterable):
+ if not isinstance(idx, collections.abc.Iterable):
raise TypeError('Should be an iterable object of int.')
size.append(len(idx))
size += list(target.varshape[len(self.index):])
diff --git a/brainpy/_src/initialize/tests/test_decay_inits.py b/brainpy/_src/initialize/tests/test_decay_inits.py
index bbab6d26d..22e1fa023 100644
--- a/brainpy/_src/initialize/tests/test_decay_inits.py
+++ b/brainpy/_src/initialize/tests/test_decay_inits.py
@@ -14,8 +14,8 @@
# visualization
def mat_visualize(matrix, cmap=None):
if cmap is None:
- cmap = plt.cm.get_cmap('coolwarm')
- plt.cm.get_cmap('coolwarm')
+ cmap = plt.colormaps.get_cmap('coolwarm')
+ plt.colormaps.get_cmap('coolwarm')
im = plt.matshow(matrix, cmap=cmap)
plt.colorbar(mappable=im, shrink=0.8, aspect=15)
plt.show()
diff --git a/brainpy/_src/integrators/_jaxpr_to_source_code.py b/brainpy/_src/integrators/_jaxpr_to_source_code.py
new file mode 100644
index 000000000..3fa1d9006
--- /dev/null
+++ b/brainpy/_src/integrators/_jaxpr_to_source_code.py
@@ -0,0 +1,1132 @@
+# Modified from: https://github.com/dlwh/jax_sourceror
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import ast
+import enum
+import warnings
+from collections.abc import MutableMapping, MutableSet
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import Callable, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax._src.sharding_impls import UNSPECIFIED
+from jax.core import Literal, Var, Jaxpr
+
+__all__ = [
+ 'fn_to_python_code',
+ 'jaxpr_to_python_code',
+]
+
+
+class IdentitySet(MutableSet):
+ """Set that compares objects by identity.
+
+ This is a set that compares objects by identity instead of equality. It is
+ useful for storing objects that are not hashable or that should be compared
+ by identity.
+
+ This is a mutable set, but it does not support the ``__hash__`` method and
+ therefore cannot be used as a dictionary key or as an element of another set.
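+
+    Example (an illustrative sketch)::
+
+        a, b = object(), object()
+        s = IdentitySet([a])
+        assert a in s and b not in s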
+ """
+
+    def __init__(self, iterable=None):
+        self._data = {}
+        if iterable is not None:
+            # ``MutableSet`` does not provide an ``update`` mixin, so add
+            # the initial items one by one.
+            for value in iterable:
+                self.add(value)
+
+ def __contains__(self, value):
+ return id(value) in self._data
+
+ def __iter__(self):
+ return iter(self._data.values())
+
+ def __len__(self):
+ return len(self._data)
+
+ def add(self, value):
+ self._data[id(value)] = value
+
+ def discard(self, value):
+ self._data.pop(id(value), None)
+
+ def __repr__(self):
+ return f"IdentitySet({list(repr(x) for x in self._data.values())})"
+
+ def __str__(self):
+ return f"IdentitySet({list(str(x) for x in self._data.values())})"
+
+
+class IdentityMap(MutableMapping):
+ """Map that compares keys by identity.
+
+ This is a map that compares keys by identity instead of equality. It is
+ useful for storing objects that are not hashable or that should be compared
+ by identity.
+
+ This is a mutable mapping, but it does not support the ``__hash__`` method
+ and therefore cannot be used as a dictionary key or as an element of another
+ set.
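+
+    Example (an illustrative sketch)::
+
+        key = object()
+        m = IdentityMap()
+        m[key] = 'value'
+        assert key in m and m[key] == 'value'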
+ """
+
+ def __init__(self, iterable=None):
+ self._data = {}
+ if iterable is not None:
+ self.update(iterable)
+
+ def __contains__(self, key):
+ return id(key) in self._data
+
+    def __getitem__(self, key):
+        return self._data[id(key)][1]
+
+    def __setitem__(self, key, value):
+        # store the original key alongside the value so iteration can
+        # recover the keys, as required by the ``MutableMapping`` contract
+        self._data[id(key)] = (key, value)
+
+    def __delitem__(self, key):
+        del self._data[id(key)]
+
+    def __iter__(self):
+        return (key for key, _ in self._data.values())
+
+    def __len__(self):
+        return len(self._data)
+
+    def __repr__(self):
+        return f"IdentityMap({[(repr(k), repr(v)) for k, v in self._data.values()]})"
+
+    def __str__(self):
+        return f"IdentityMap({[(str(k), str(v)) for k, v in self._data.values()]})"
+
+
+@dataclass
+class SourcerorState:
+ """State for the auto-minimizer. Basically just in charge of naming variables."""
+ _var_names: IdentityMap[Var, str] = field(default_factory=IdentityMap)
+ _skolem_count: int = 0
+
+ def name(self, var, ctx=ast.Load()) -> ast.Name:
+ return ast.Name(id=self.str_name(var), ctx=ctx)
+
+ def str_name(self, var: Var):
+ # Names things in a way vaguely compatible with
+ # JAX's naming scheme, which is 'a'-'z' followed
+ # by 'aa'-'az' etc.
+ if var in self._var_names:
+ return self._var_names[var]
+ else:
+ cur_count = len(self._var_names)
+ name = ""
+ while cur_count >= 26:
+ name += chr(ord('a') + cur_count % 26)
+ cur_count //= 26
+
+ name += chr(ord('a') + cur_count)
+
+ name = name[::-1]
+
+ self._var_names[var] = name
+
+ return name
+
+ def skolem(self, prefix: str):
+ self._skolem_count += 1
+ return f"{prefix}_{self._skolem_count}"
+
+
+prefix_imports = set()
+
+
+@contextmanager
+def catch_imports():
+ try:
+ prefix_imports.clear()
+ yield
+ finally:
+ prefix_imports.clear()
+
+
+def fn_to_python_code(fn, *args, **kwargs):
+ """
+    Given a function defined in terms of JAX primitives and its arguments,
+    trace the function to a jaxpr and return equivalent Python source code.
+
+    :param fn: The function to generate code for
+    :param args: The positional arguments to the function
+    :param kwargs: The keyword arguments to the function
+    :return: Python source code equivalent to the traced function
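+
+    Example (an illustrative sketch; any function traceable by
+    ``jax.make_jaxpr`` works the same way)::
+
+        import jax.numpy as jnp
+
+        def f(x, y):
+            return jnp.sum(x * y + 1.0)
+
+        src = fn_to_python_code(f, jnp.ones(3), jnp.ones(3))
+        print(src)  # a Python function reproducing the traced jaxpr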
+ """
+ closed_jaxpr = jax.make_jaxpr(fn)(*args, **kwargs)
+ jaxpr = constant_fold_jaxpr(closed_jaxpr.jaxpr)
+ state = SourcerorState()
+ try:
+ name = fn.__name__
+ except AttributeError:
+ name = "unknown"
+ with catch_imports():
+ node = jaxpr_to_py_ast(state, jaxpr, fn_name=name)
+ node = _maybe_wrap_fn_for_leaves(node, fn, len(args) + len(kwargs))
+ ast.fix_missing_locations(node)
+ source = ast.unparse(node)
+ if len(prefix_imports):
+ source = "\n".join(prefix_imports) + "\n\n" + source
+ return source
+
+
+def jaxpr_to_python_code(jaxpr: jax.core.Jaxpr,
+ fn_name: str = "generated_function"):
+ """
+    Given a JAX jaxpr, return equivalent Python source code.
+
+    :param jaxpr: The jaxpr to generate code for.
+    :param fn_name: The name of the generated function.
+    :return: Python source code equivalent to the jaxpr
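+
+    Example (an illustrative sketch)::
+
+        import jax
+        import jax.numpy as jnp
+
+        closed = jax.make_jaxpr(lambda x: x * 2.0 + 1.0)(jnp.ones(4))
+        print(jaxpr_to_python_code(closed.jaxpr, fn_name='affine'))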
+ """
+ jaxpr = constant_fold_jaxpr(jaxpr)
+ state = SourcerorState()
+ with catch_imports():
+ node = jaxpr_to_py_ast(state, jaxpr, fn_name=fn_name)
+ ast.fix_missing_locations(node)
+ source = ast.unparse(node)
+ if len(prefix_imports):
+ source = "\n".join(prefix_imports) + "\n\n" + source
+ return source
+
+
+def register_prim_handler(prim_name, handler):
+ """
+    Register a handler that converts equations of the given primitive into Python AST.
+
+    :param prim_name: name of the JAX primitive (e.g. ``'dot_general'``)
+    :param handler: callable taking ``(state, eqn)`` and returning one or more AST statements
+    :return: None
+ """
+ if prim_name in prim_to_python:
+ warnings.warn(f"Overwriting handler for primitive {prim_name}")
+ prim_to_python[prim_name] = handler
+
+
+def register_prim_as(prim_name):
+ """
+ Decorator to register a handler for a primitive.
+
+    :param prim_name: name of the JAX primitive to handle
+    :return: the decorator, which registers the handler and returns it unchanged
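+
+    Example (an illustrative sketch mirroring the handlers defined below)::
+
+        @register_prim_as('exp')
+        def _astify_exp(state, eqn):
+            return normal_fn('jax.lax.exp')(state, eqn)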
+ """
+
+ def decorator(fn):
+ register_prim_handler(prim_name, fn)
+ return fn
+
+ return decorator
+
+
+def _assign_stmt(call_expr: Callable):
+ """
+ Create a handler for a primitive that is a simple assignment.
+    :param call_expr: callable building the right-hand-side AST expression
+        from the astified inputs and parameters
+    :return: a handler taking ``(state, eqn)`` and returning an ``ast.Assign``
+ """
+
+ def binop_fn(state, eqn):
+ invars = [_astify_atom(state, v) for v in eqn.invars]
+ outvars = _astify_outvars(state, eqn.outvars)
+ return ast.Assign(
+ outvars,
+ call_expr(
+ *invars,
+ **{k: _astify_value(v) for k, v in eqn.params.items()}
+ )
+ )
+
+ return binop_fn
+
+
+def _binop_fn(op: ast.operator):
+ return _assign_stmt(lambda x, y: ast.BinOp(left=x, op=op, right=y))
+
+
+def _cmpop_fn(op: ast.cmpop):
+ return _assign_stmt(lambda x, y: ast.Compare(left=x, ops=[op], comparators=[y]))
+
+
+def normal_fn(fn_name):
+ """
+ Create a handler for a normal function call.
+    :param fn_name: fully qualified name of the function to call (e.g. ``'jax.lax.neg'``)
+    :return: a handler that assigns the call result to the equation outputs
+ """
+ return _assign_stmt(
+ lambda *args, **kwargs: ast.Call(
+ func=ast.Name(id=fn_name, ctx=ast.Load()),
+ args=list(args),
+ keywords=[ast.keyword(arg=k, value=v) for k, v in kwargs.items()]
+ )
+ )
+
+
+def _reduce_fn(fn_name: str):
+ def reduce_fn_inner(state: SourcerorState, eqn):
+ invars = [_astify_atom(state, v) for v in eqn.invars]
+ outvars = _astify_outvars(state, eqn.outvars)
+ if eqn.params:
+ params = eqn.params.copy()
+ params['axis'] = tuple(params['axes'])
+ del params['axes']
+ call_op = ast.Call(
+ func=ast.Name(id=fn_name, ctx=ast.Load()),
+ args=invars,
+ keywords=[ast.keyword(arg=k, value=_astify_value(v)) for k, v in params.items()]
+ )
+ else:
+ call_op = ast.Call(
+ func=ast.Name(id=fn_name, ctx=ast.Load()),
+ args=invars,
+ keywords=[]
+ )
+
+ return ast.Assign(outvars, call_op)
+
+ return reduce_fn_inner
+
+
+prim_to_python = dict()
+
+register_prim_handler('add', _binop_fn(ast.Add()))
+register_prim_handler('sub', _binop_fn(ast.Sub()))
+register_prim_handler('mul', _binop_fn(ast.Mult()))
+register_prim_handler('div', _binop_fn(ast.Div()))
+register_prim_handler('neg', normal_fn('jax.lax.neg'))
+register_prim_handler('lt', _cmpop_fn(ast.Lt()))
+register_prim_handler('gt', _cmpop_fn(ast.Gt()))
+register_prim_handler('le', _cmpop_fn(ast.LtE()))
+register_prim_handler('ge', _cmpop_fn(ast.GtE()))
+register_prim_handler('eq', _cmpop_fn(ast.Eq()))
+register_prim_handler('ne', _cmpop_fn(ast.NotEq()))
+register_prim_handler('min', normal_fn('jax.lax.min'))
+register_prim_handler('max', normal_fn('jax.lax.max'))
+register_prim_handler('select_n', normal_fn('jax.lax.select_n'))
+register_prim_handler('squeeze', normal_fn('jax.lax.squeeze'))
+register_prim_handler('broadcast', normal_fn('jax.lax.broadcast'))
+register_prim_handler('reduce_sum', _reduce_fn('jax.numpy.sum'))
+register_prim_handler('transpose', normal_fn('jax.lax.transpose'))
+
+
+def _maybe_wrap_fn_for_leaves(node, f, num_args):
+ if len(node.args.args) == num_args:
+ return node
+
+ wrapped_node = ast.FunctionDef(
+ name=f.__name__,
+ args=ast.arguments(
+ args=[],
+ vararg=ast.arg(arg="args", annotation=None),
+ kwarg=ast.arg(arg="kwargs", annotation=None),
+ kwonlyargs=[], kw_defaults=[], defaults=[],
+ posonlyargs=[]
+ ),
+ body=[
+ node,
+ ast.Return(
+ ast.Call(
+ func=ast.Name(id=node.name, ctx=ast.Load()),
+ args=[
+ ast.Starred(
+ ast.Call(
+ func=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()),
+ attr="tree_leaves",
+ ctx=ast.Load()),
+ args=[ast.Tuple(elts=[ast.Name(id="args", ctx=ast.Load()),
+ ast.Name(id="kwargs", ctx=ast.Load())],
+ ctx=ast.Load())],
+ keywords=[]
+ )
+ )
+ ],
+ keywords=[]
+ )
+ ),
+ ],
+ decorator_list=[]
+ )
+
+ return wrapped_node
+
+
+def jaxpr_to_py_ast(state: SourcerorState,
+ jaxpr: jax.core.Jaxpr,
+ fn_name: str = "function"):
+ # Generate argument declarations
+ ast_args = [ast.arg(arg=state.str_name(var), annotation=None)
+ for var in jaxpr.invars]
+ ast_args = ast.arguments(args=ast_args,
+ vararg=None,
+ kwonlyargs=[],
+ kw_defaults=[],
+ kwarg=None,
+ defaults=[],
+ posonlyargs=[])
+
+ stmts = []
+
+ # Generate body of the function
+ for eqn in jaxpr.eqns:
+ prim = str(eqn.primitive)
+ if prim in prim_to_python:
+ eqn_stmts = prim_to_python[prim](state, eqn)
+ else:
+ eqn_stmts = normal_fn(prim)(state, eqn)
+
+ if isinstance(eqn_stmts, list):
+ stmts.extend(eqn_stmts)
+ else:
+ stmts.append(eqn_stmts)
+
+ # Generate return statement
+ if len(jaxpr.outvars) == 1:
+ returns = state.name(jaxpr.outvars[0])
+ else:
+ returns = ast.Tuple(elts=[state.name(var) for var in jaxpr.outvars], ctx=ast.Load())
+ stmts.append(ast.Return(value=returns))
+
+ return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[])
+
+
+def constant_fold_jaxpr(jaxpr: jax.core.Jaxpr):
+ """
+ Given a jaxpr, return a new jaxpr with all constant folding done.
+ """
+ return partial_eval_jaxpr(jaxpr, {})
+
+
+def partial_eval_jaxpr(jaxpr, env):
+ env = env.copy()
+ new_eqns = []
+
+ def read(var):
+ if isinstance(var, Literal):
+ return var.val
+ else:
+ return env.get(var, None)
+
+ def read_or_self(var):
+ out = read(var)
+ if out is None:
+ return var
+ elif isinstance(out, Var):
+ return out
+ elif isinstance(out, Literal):
+ return Literal(out.val, var.aval)
+ else:
+ assert not isinstance(out, Jaxpr)
+ return Literal(out, var.aval)
+
+ for eqn in jaxpr.eqns:
+ vals = [read(var) for var in eqn.invars]
+ if eqn.primitive.name in constant_fold_blacklist:
+ new_eqns.append(eqn)
+ elif all(val is not None for val in vals):
+ # go ahead and eval it
+ out = _eval_eqn(eqn, vals)
+
+ # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values
+ if isinstance(out, Jaxpr):
+ # we need to inline this
+ new_eqns.extend(out.eqns)
+ out = out.outvars
+ elif not isinstance(out, tuple) and not isinstance(out, list):
+ out = (out,)
+
+ for var, val in zip(eqn.outvars, out):
+ assert not isinstance(val, Jaxpr)
+ if isinstance(val, Literal):
+ env[var] = val.val
+ else:
+ env[var] = val
+ else:
+ new_eqns.append(eqn)
+
+ # now that we've evaled everything, inline all the constants
+ out_eqns = []
+ for eqn in new_eqns:
+ eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars))
+ out_eqns.append(eqn)
+
+ invars_still_used = IdentitySet()
+ for eqn in out_eqns:
+ for var in eqn.invars:
+ invars_still_used.add(var)
+
+ invars = tuple(var for var in jaxpr.invars if var in invars_still_used)
+
+ # sub in any constants for outvars
+ outvars = tuple(read_or_self(var) for var in jaxpr.outvars)
+
+ return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars)
+
+
+def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jnp.ndarray]:
+ if eqn.primitive.name == "closed_call":
+        assert eqn.primitive.call_primitive
+        assert not eqn.primitive.map_primitive
+
+ out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr,
+ {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)})
+    else:
+        # all remaining primitives (including scan) can be evaluated directly
+        # by binding them to the fully-known input values
+        out = eqn.primitive.bind(*vals, **eqn.params)
+ return out
+
+
+@register_prim_as('dot_general')
+def _astify_dot_general(state, eqn):
+ x, y = eqn.invars
+ d = eqn.params['dimension_numbers']
+ precision = eqn.params['precision']
+ preferred_element_type = eqn.params['preferred_element_type']
+
+ has_dtype = preferred_element_type is None or x.aval.dtype == y.aval.dtype == preferred_element_type
+
+ # recognize simple matmul case
+    if d == (((1,), (0,)), ((), ())) and precision is None:
+ invars = [_astify_atom(state, x), _astify_atom(state, y)]
+ outvars = _astify_outvars(state, eqn.outvars)
+ out = ast.Assign(targets=outvars, value=ast.Call(
+ func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='matmul', ctx=ast.Load()), args=invars,
+ keywords=[]))
+ if not has_dtype:
+ out = ast.Assign(targets=outvars,
+ value=ast.Call(func=ast.Attribute(value=out.value, attr='astype', ctx=ast.Load()),
+ args=[_astify_value(preferred_element_type)], keywords=[]))
+
+ return out
+
+ # TODO: convert to einsum?
+
+ invars = [_astify_atom(state, x),
+ _astify_atom(state, y),
+ _astify_value(d),
+ _astify_value(precision),
+ _astify_value(preferred_element_type)]
+ outvars = _astify_outvars(state, eqn.outvars)
+ return ast.Assign(
+ targets=outvars,
+ value=ast.Call(
+ func=ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()), attr='dot_general', ctx=ast.Load()),
+ args=invars,
+ keywords=[]
+ )
+ )
+
+
+@register_prim_as('dynamic_slice')
+def _sourcify_dynamic_slice(state, eqn):
+ sliced = eqn.invars[0]
+ invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load())
+ outvars = _astify_outvars(state, eqn.outvars)
+ params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()]
+ return ast.Assign(
+ targets=outvars,
+ value=ast.Call(
+ func=ast.Attribute(
+ value=ast.Name(id='jax.lax', ctx=ast.Load()),
+ attr='dynamic_slice',
+ ctx=ast.Load()
+ ),
+ args=[_astify_atom(state, sliced), invars],
+ keywords=params
+ )
+ )
+
+
+@register_prim_as('slice')
+def _sourcify_slice(state, eqn):
+ sliced = eqn.invars[0]
+ # invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load())
+ outvars = _astify_outvars(state, eqn.outvars)
+ start_indices = eqn.params['start_indices']
+ limit_indices = eqn.params['limit_indices']
+ strides = eqn.params['strides']
+ if strides is None:
+ strides = (None,) * len(start_indices)
+ indices = [_astify_value(slice(s, e, stride))
+ for s, e, stride in zip(start_indices, limit_indices, strides)]
+ # params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()]
+ return ast.Assign(
+ targets=outvars,
+ value=ast.Subscript(
+ value=_astify_atom(state, sliced),
+ slice=ast.Tuple(elts=indices, ctx=ast.Load()),
+ ctx=ast.Load()
+ )
+ )
+
+
+@register_prim_as('dynamic_update_slice')
+def _sourcify_dynamic_update_slice(state, eqn):
+ sliced = eqn.invars[0]
+ # the first two arguments are the sliced array and the update array
+ # the remaining are start indices and should be packaged into a tuple
+ target = _astify_atom(state, eqn.invars[0])
+ update = _astify_atom(state, eqn.invars[1])
+ start_indices = maybe_tuple_vars([_astify_atom(state, var) for var in eqn.invars[2:]])
+ outvars = _astify_outvars(state, eqn.outvars)
+
+ return ast.Assign(targets=outvars, value=ast.Call(
+ func=ast.Attribute(
+ value=ast.Name(id='jax.lax', ctx=ast.Load()),
+ attr='dynamic_update_slice',
+ ctx=ast.Load()
+ ),
+ args=[target, update, start_indices],
+ keywords=[]
+ ))
+
+
+@register_prim_as('convert_element_type')
+def _astify_convert_element_type(state, eqn):
+ # now we use ast
+ outvars = _astify_outvars(state, eqn.outvars)
+ assert len(eqn.invars) == 1
+ invar = _astify_atom(state, eqn.invars[0])
+ dtype = _astify_value(eqn.params['new_dtype'])
+ return ast.Assign(targets=outvars, value=ast.Call(
+ func=ast.Attribute(
+ value=invar,
+ attr='astype',
+ ctx=ast.Load()
+ ),
+ args=[dtype],
+ keywords=[]
+ ))
+
+
+def is_array(arr):
+ return isinstance(arr, (np.ndarray, np.generic, jnp.ndarray))
+
+
+def _astify_array(value):
+ assert is_array(value)
+ if isinstance(value, np.int64):
+ return ast.Constant(value=int(value))
+
+ if value.ndim == 0 and value.dtype in (jnp.float32, jnp.int32, jnp.bool_, jnp.int64):
+ return ast.Constant(value=value.item())
+
+ if value.ndim == 0:
+ dtype_value = _astify_value(value.dtype)
+ return ast.Call(
+ dtype_value,
+ args=[ast.Constant(value=value.item())],
+ keywords=[],
+ )
+
+ values = value.tolist()
+
+ def rec_astify_list(values):
+ if isinstance(values, list):
+ return ast.List(elts=[rec_astify_list(val) for val in values], ctx=ast.Load())
+ else:
+ return ast.Constant(value=values)
+
+ return ast.Call(
+ func=ast.Attribute(
+ value=ast.Name(id='jax.numpy', ctx=ast.Load()),
+ attr='array',
+ ctx=ast.Load()
+ ),
+ args=[rec_astify_list(values)],
+ keywords=[ast.keyword(arg='dtype',
+ value=_astify_value(value.dtype))]
+ )
+
+
+def _astify_atom(state: SourcerorState, var: Union[Literal, Var]):
+ if isinstance(var, Literal):
+ return _astify_value(var.val)
+ elif isinstance(var, Var):
+ return state.name(var)
+ else:
+ raise NotImplementedError()
+
+
+def _astify_value(value):
+ assert not isinstance(value, (Literal, Var))
+
+ if is_array(value):
+ return _astify_array(value)
+ elif isinstance(value, (int, bool, float, str, type(None))):
+ return ast.Constant(value=value)
+ elif isinstance(value, (tuple, list)):
+ return ast.Tuple(elts=[_astify_value(v) for v in value], ctx=ast.Load())
+ elif isinstance(value, jnp.dtype):
+ # return ast.Call(func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='dtype', ctx=ast.Load()), args=[ast.Constant(value=str(value))], keywords=[])
+ if value.name in ('float32', 'float64', 'int32', 'int64', 'bfloat16', 'float16'):
+ # return ast.Constant(value=getattr(jnp, value.name))
+ return ast.Attribute(
+ value=ast.Name(id='jax.numpy', ctx=ast.Load()),
+ attr=value.name,
+ ctx=ast.Load()
+ )
+ elif value.name == 'bool':
+ return ast.Attribute(
+ value=ast.Name(id='jax.numpy', ctx=ast.Load()),
+ attr='bool_',
+ ctx=ast.Load()
+ )
+ else:
+ return ast.Call(
+ func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()),
+ attr='dtype',
+ ctx=ast.Load()),
+ args=[ast.Constant(value=str(value))],
+ keywords=[]
+ )
+ elif value is UNSPECIFIED:
+ prefix_imports.add('from jax._src.sharding_impls import UNSPECIFIED')
+ return ast.Name(id='UNSPECIFIED', ctx=ast.Load())
+ elif isinstance(value, enum.Enum):
+ return ast.Attribute(
+ value=ast.Name(id=value.__class__.__qualname__, ctx=ast.Load()),
+ attr=value.name,
+ ctx=ast.Load()
+ )
+
+ else:
+ warnings.warn(f"Unknown value type {type(value)}")
+ return ast.parse(repr(value)).body[0]
+
+
+def _astify_outvars(state, outvars):
+ out = [state.name(v, ctx=ast.Store()) for v in outvars]
+ if len(out) == 1:
+ return out
+ else:
+ return [ast.Tuple(elts=out, ctx=ast.Store())]
+
+
+def maybe_tuple_vars(vars):
+ if len(vars) == 1:
+ return vars[0]
+ else:
+ return ast.Tuple(elts=vars, ctx=ast.Load())
+
+
+def maybe_untuple_vars(var, is_tuple):
+ if is_tuple:
+ return ast.Starred(value=var, ctx=ast.Load())
+ else:
+ return var
+
+
+@register_prim_as('scan')
+def _astify_scan(state, eqn):
+ assert eqn.primitive.name == 'scan'
+
+ # the args to scan are [constants, carry, xs]
+ # constants aren't exposed in the Python API, so we need to handle them specially (we use a lambda)
+ num_consts = eqn.params['num_consts']
+ num_carry = eqn.params['num_carry']
+
+ # TODO: bring back map
+ # if num_carry == 0:
+ # this is a map
+ # return _astify_map(eqn)
+
+ constant_args = eqn.invars[:num_consts]
+ carries = eqn.invars[num_consts:num_consts + num_carry]
+ xs = eqn.invars[num_consts + num_carry:]
+
+ jaxpr = eqn.params['jaxpr']
+
+ if num_consts != 0:
+ # we want to construct an environment where we partial eval the function using the constants as the env
+ env = dict(zip(jaxpr.jaxpr.invars, constant_args))
+ jaxpr = partial_eval_jaxpr(jaxpr.jaxpr, env)
+ else:
+ jaxpr = constant_fold_jaxpr(jaxpr.jaxpr)
+
+ fn_name = state.skolem('fn')
+ fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name)
+
+ length = _astify_value(eqn.params['length'])
+ unroll = _astify_value(eqn.params['unroll'])
+ reverse = _astify_value(eqn.params['reverse'])
+
+ stmts = []
+
+ if num_carry != 1 or len(jaxpr.invars) != 2:
+ # what we want is something like:
+ # fn_name = lambda carry, xs: fn_name(*carry, *xs)
+ # jax.lax.scan(fn_name, (carries...), (xs...))
+
+ modified_signature = ast.arguments(
+ args=[ast.arg(arg='carry'), ast.arg(arg='x')],
+ vararg=None,
+ kwonlyargs=[],
+ kw_defaults=[],
+ kwarg=None,
+ defaults=[],
+ posonlyargs=[]
+ )
+
+ initial_assign = ast.Assign(
+ targets=[ast.Tuple(elts=[ast.Name(a.arg) for a in fn_ast.args.args],
+ ctx=ast.Store())],
+ value=ast.Tuple(
+ elts=[maybe_untuple_vars(ast.Name(id='carry', ctx=ast.Load()), num_carry != 1),
+ maybe_untuple_vars(ast.Name(id='x', ctx=ast.Load()), len(xs) != 1)]
+ )
+ )
+
+ fn_return = fn_ast.body[-1]
+ assert isinstance(fn_return, ast.Return)
+
+ fn_return_value = fn_return.value
+
+ if isinstance(fn_return_value, ast.Tuple):
+ fn_return_value = fn_return_value.elts
+ ret_carries = maybe_tuple_vars(fn_return_value[:num_carry])
+ ret_ys = maybe_tuple_vars(fn_return_value[num_carry:])
+ elif num_carry == 0:
+ ret_carries = _astify_value(())
+ ret_ys = fn_return_value
+ else:
+ ret_carries = fn_return_value
+ ret_ys = _astify_value(())
+
+ scan_return = ast.Return(
+ value=ast.Tuple(elts=[ret_carries, ret_ys], ctx=ast.Load())
+ )
+
+ new_body = [initial_assign] + list(fn_ast.body[:-1]) + [scan_return]
+
+ fn_ast = ast.FunctionDef(
+ name=fn_name,
+ args=modified_signature,
+ body=new_body,
+ decorator_list=[]
+ )
+
+ stmts.append(fn_ast)
+
+ scan_call = ast.Assign(
+ # targets=_astify_outvars(eqn.outvars),
+ targets=[
+ ast.Tuple(
+ elts=[ast.Name(id='final_carry', ctx=ast.Store()),
+ ast.Name(id='ys', ctx=ast.Store())],
+ ctx=ast.Store()
+ )
+ ],
+ value=ast.Call(
+ func=ast.Name(id='jax.lax.scan', ctx=ast.Load()),
+ args=[ast.Name(id=fn_name, ctx=ast.Load()),
+ maybe_tuple_vars([_astify_atom(state, v) for v in carries]),
+ maybe_tuple_vars([_astify_atom(state, v) for v in xs])],
+ keywords=[ast.keyword(arg='length', value=length),
+ ast.keyword(arg='unroll', value=unroll),
+ ast.keyword(arg='reverse', value=reverse)]
+ )
+ )
+ stmts.append(scan_call)
+
+ if num_carry > 0:
+ assign_carry = ast.Assign(
+ targets=_astify_outvars(state, eqn.outvars[:num_carry]),
+ value=ast.Name(id='final_carry', ctx=ast.Load())
+ )
+
+ stmts.append(assign_carry)
+
+ if num_carry < len(eqn.outvars):
+ assign_ys = ast.Assign(
+ targets=_astify_outvars(state, eqn.outvars[num_carry:]),
+ value=ast.Name(id='ys', ctx=ast.Load())
+ )
+
+ stmts.append(assign_ys)
+ else:
+ stmts.append(fn_ast)
+
+ scan_call = ast.Assign(
+ targets=_astify_outvars(state, eqn.outvars),
+ value=ast.Call(
+ func=ast.Name(id='jax.lax.scan', ctx=ast.Load()),
+ args=[ast.Name(id=fn_name, ctx=ast.Load())] + [_astify_atom(state, v) for v in eqn.invars],
+ keywords=[ast.keyword(arg='length', value=length),
+ ast.keyword(arg='unroll', value=unroll),
+ ast.keyword(arg='reverse', value=reverse)]
+ )
+ )
+
+ stmts.append(scan_call)
+
+ return stmts
+
+
+def _astify_map(state, eqn):
+ assert eqn.primitive.name == 'scan'
+ assert eqn.params['num_carry'] == 0
+
+ jaxpr = eqn.params['jaxpr']
+ jaxpr = constant_fold_jaxpr(jaxpr.jaxpr)
+
+ fn_name = state.skolem('fn')
+ fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name)
+
+ # map is a bit funny, because the jaxpr takes K args, but the jax.lax.map function takes a single tuple arg
+ # so we need to use a lambda to redirect the call
+ lam = ast.parse(f"lambda args: {fn_name}(*args)").body[0]
+
+ assign = ast.Assign(
+ targets=_astify_outvars(state, eqn.outvars),
+ value=ast.Call(
+ func=ast.Name(id='jax.lax.map', ctx=ast.Load()),
+ args=[lam,
+ ast.Tuple(elts=[_astify_atom(state, v) for v in eqn.invars],
+ ctx=ast.Load())],
+ keywords=[]
+ )
+ )
+
+ return [fn_ast, assign]
+
+
+@register_prim_as('closed_call')
+def _astify_closed_call(state, eqn):
+ # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr,
+ # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)})
+ raw_jaxpr = eqn.params['call_jaxpr'].jaxpr
+ literal_args = {k: v.val
+ for k, v in zip(raw_jaxpr.invars, eqn.invars)
+ if isinstance(v, Literal)}
+    call_jaxpr = partial_eval_jaxpr(raw_jaxpr, literal_args)
+    fn_name = state.skolem('fn')
+
+    fn_ast = jaxpr_to_py_ast(state, call_jaxpr, fn_name)
+
+ invars = [_astify_atom(state, v)
+ for v in eqn.invars
+ if not isinstance(v, Literal)]
+ outvars = _astify_outvars(state, eqn.outvars)
+
+ assign = ast.Assign(
+ targets=outvars,
+ value=ast.Call(
+ func=ast.Name(id=fn_name, ctx=ast.Load()),
+ args=invars,
+ keywords=[]
+ )
+ )
+
+ return [fn_ast, assign]
+
+
+@register_prim_as('pjit')
+def _astify_pjit(state, eqn):
+ # this one's a real pain.
+ # pjit's params are :
+ # jaxpr
+ # donated_invars:
+ # in_shardings, out_shardings
+ # resource env
+ # name (yay)
+ # keep_unused, inline (which we won't use)
+
+ jaxpr = eqn.params['jaxpr']
+ donated_invars = eqn.params['donated_invars']
+ in_shardings = eqn.params['in_shardings']
+ out_shardings = eqn.params['out_shardings']
+ resource_env = eqn.params['resource_env']
+ name = eqn.params['name']
+
+ can_ignore_donated = not any(donated_invars)
+
+ # preprocess the function
+ jaxpr = constant_fold_jaxpr(jaxpr.jaxpr)
+ fn_name = state.skolem(name)
+ fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name)
+
+ in_shardings = _astify_value(in_shardings)
+ out_shardings = _astify_value(out_shardings)
+
+ keywords = [
+ ast.keyword(arg='in_shardings', value=in_shardings),
+ ast.keyword(arg='out_shardings', value=out_shardings),
+ ]
+
+ if not can_ignore_donated:
+ donated_invars = _astify_value(donated_invars)
+ keywords.append(ast.keyword(arg='donated_invars', value=donated_invars))
+
+ jitted_fn = ast.Call(
+ func=ast.Attribute(
+ ast.Name(id='jax', ctx=ast.Load()),
+ attr='jit'
+ ),
+ args=[ast.Name(id=fn_name, ctx=ast.Load())],
+ keywords=keywords
+ )
+
+ assign = ast.Assign(
+ targets=_astify_outvars(state, eqn.outvars),
+ value=ast.Call(
+ func=jitted_fn,
+ args=[_astify_atom(state, v) for v in eqn.invars],
+ keywords=[]
+ )
+ )
+
+ return [fn_ast, assign]
+
+
+@register_prim_as('remat2')
+def _astify_remat(state: SourcerorState, eqn):
+ # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr,
+ # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)})
+    call_jaxpr = constant_fold_jaxpr(eqn.params['jaxpr'])
+    fn_name = state.skolem('fn')
+
+    fn_ast = jaxpr_to_py_ast(state, call_jaxpr, fn_name)
+
+ invars = [_astify_atom(state, v) for v in eqn.invars]
+ outvars = _astify_outvars(state, eqn.outvars)
+
+ lam = ast.Assign(
+ targets=[ast.Name(id=f"ckpt_{fn_name}", ctx=ast.Store())],
+ # value=ast.parse(f"jax.checkpoint({fn_name})").body[0]
+ value=ast.Call(
+ func=ast.Name(id='jax.checkpoint', ctx=ast.Load()),
+ args=[ast.Name(id=fn_name, ctx=ast.Load())],
+ keywords=[])
+ )
+
+ assign = ast.Assign(
+ targets=outvars,
+ value=ast.Call(
+ func=ast.Name(id=f"ckpt_{fn_name}"),
+ args=invars,
+ keywords=[]
+ ))
+
+ return [fn_ast, lam, assign]
+
+
+@register_prim_as('reshape')
+def _astify_reshape(state, eqn):
+ # the lax reshape is a bit different, because it can combine a transpose and reshape into one.
+ # np.reshape(np.transpose(operand, dimensions), new_sizes)
+ dimensions = eqn.params['dimensions']
+ new_sizes = eqn.params['new_sizes']
+
+ source = _astify_atom(state, eqn.invars[0])
+
+ if dimensions is not None:
+ source = ast.Call(
+ func=ast.Name(id='jax.numpy.transpose', ctx=ast.Load()),
+ args=[source, _astify_value(dimensions)],
+ keywords=[]
+ )
+
+ assign = ast.Assign(
+ targets=_astify_outvars(state, eqn.outvars),
+ value=ast.Call(
+ func=ast.Name(id='jax.numpy.reshape', ctx=ast.Load()),
+ args=[source, _astify_value(new_sizes)],
+ keywords=[]
+ ))
+
+ return [assign]
+
+
+@register_prim_as('add_any')
+def _astify_add_any(state, eqn):
+    # `add_any` is the primitive JAX autodiff uses to sum tangents/cotangents;
+    # it behaves like an ordinary addition, so lower it to `+`.
+ return _binop_fn(ast.Add())(state, eqn)
+
+
+@register_prim_as('broadcast_in_dim')
+def _astify_broadcast_in_dim(state, eqn):
+ # broadcast_in_dim is how zeros, ones, full, etc are implemented,
+ # so we prefer to use those where possible
+ assert len(eqn.invars) == 1
+ value = eqn.invars[0]
+ shape = eqn.params['shape']
+ broadcast_dimensions = eqn.params['broadcast_dimensions']
+
+ if not isinstance(value, Literal) or broadcast_dimensions != ():
+ return normal_fn('jax.lax.broadcast_in_dim')(state, eqn)
+
+ if not isinstance(value.val, np.ndarray) or value.val.ndim != 0:
+ return normal_fn('jax.lax.broadcast_in_dim')(state, eqn)
+ else:
+ constant_value = value.val.item()
+ if constant_value == 0:
+ call = ast.Call(
+ ast.Attribute(
+ value=ast.Name(id='jax.numpy', ctx=ast.Load()),
+ attr='zeros',
+ ctx=ast.Load()
+ ),
+ args=[_astify_value(shape),
+ _astify_value(value.val.dtype)],
+ keywords=[]
+ )
+ elif constant_value == 1:
+ call = ast.Call(
+ ast.Attribute(
+ value=ast.Name(id='jax.numpy', ctx=ast.Load()),
+ attr='ones',
+ ctx=ast.Load()
+ ),
+ args=[_astify_value(shape),
+ _astify_value(value.val.dtype)],
+ keywords=[]
+ )
+ else:
+ call = ast.Call(
+ ast.Attribute(
+ value=ast.Name(id='jax.numpy', ctx=ast.Load()),
+ attr='full',
+ ctx=ast.Load()
+ ),
+ args=[_astify_value(shape),
+ _astify_value(constant_value),
+ _astify_value(value.val.dtype)],
+ keywords=[]
+ )
+
+ return [ast.Assign(
+ targets=_astify_outvars(state, eqn.outvars),
+ value=call
+ )]
+
+
+@register_prim_as('random_wrap')
+def _astify_random_wrap(state, eqn):
+ # we treat this as a noop
+ return ast.Assign(
+ targets=_astify_outvars(state, eqn.outvars),
+ value=_astify_atom(state, eqn.invars[0])
+ )
+
+
+constant_fold_blacklist = {
+ 'broadcast_in_dim',
+ 'broadcast',
+}
diff --git a/brainpy/_src/integrators/base.py b/brainpy/_src/integrators/base.py
index 6168ffd87..7853123bc 100644
--- a/brainpy/_src/integrators/base.py
+++ b/brainpy/_src/integrators/base.py
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
-from typing import Dict, Sequence, Union
+from typing import Dict, Sequence, Union, Callable
+
+import jax
from brainpy._src.math.object_transform.base import BrainPyObject
from brainpy._src.math import TimeDelay, LengthDelay
@@ -9,6 +11,9 @@
from brainpy.errors import DiffEqError
from .constants import DT
+from ._jaxpr_to_source_code import jaxpr_to_python_code
+from contextlib import contextmanager
+
__all__ = [
'Integrator',
]
@@ -58,6 +63,9 @@ def __init__(
self._state_delays[key] = delay
self.register_implicit_nodes(self._state_delays)
+ # math expression
+ self._math_expr = None
+
@property
def dt(self):
"""The numerical integration precision."""
@@ -119,6 +127,18 @@ def state_delays(self):
def state_delays(self, value):
raise ValueError('Cannot set "state_delays" by users.')
+ def _call_integral(self, *args, **kwargs):
+ if _during_compile:
+ jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs)
+ outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs))
+ _, tree = jax.tree.flatten(out_shapes)
+ new_vars = tree.unflatten(outs)
+ self._math_expr = jaxpr_to_python_code(jaxpr.jaxpr)
+
+ else:
+ new_vars = self.integral(**kwargs)
+ return new_vars
+
def __call__(self, *args, **kwargs):
assert self.integral is not None, 'Please build the integrator first.'
@@ -127,7 +147,9 @@ def __call__(self, *args, **kwargs):
kwargs[self.arg_names[i]] = arg
# integral
- new_vars = self.integral(**kwargs)
+ new_vars = self._call_integral(**kwargs)
+
+ # post-process
if len(self.variables) == 1:
dict_vars = {self.variables[0]: new_vars}
else:
@@ -146,3 +168,31 @@ def __call__(self, *args, **kwargs):
f'While we got {delay}')
return new_vars
+
+ def to_math_expr(self):
+ if self._math_expr is None:
+ raise ValueError('Please call ``brainpy.integrators.compile_integrators`` first.')
+ return self._math_expr
+
+
+_during_compile = False
+
+
+@contextmanager
+def _during_compile_context():
+ global _during_compile
+ try:
+ _during_compile = True
+ yield
+ finally:
+ _during_compile = False
+
+
+def compile_integrators(f: Callable, *args, **kwargs):
+ """
+  Compile the integrators used in the given function.
+
+  Running ``f(*args, **kwargs)`` under this context makes every triggered
+  :py:class:`Integrator` record its update rule as Python source code, which
+  can later be retrieved via ``Integrator.to_math_expr()``.
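+
+  Example (a sketch mirroring ``brainpy/_src/integrators/tests/test_to_math_expr.py``,
+  where ``model`` is a ``DynSysGroup`` whose update rules use integrators)::
+
+    import brainpy as bp
+
+    bp.integrators.compile_integrators(model.step_run, 0, 0.)
+    for intg in model.nodes().subset(bp.Integrator).values():
+      print(intg.to_math_expr())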
+ """
+ with _during_compile_context():
+ return f(*args, **kwargs)
+
+
diff --git a/brainpy/_src/integrators/ode/base.py b/brainpy/_src/integrators/ode/base.py
index b34dd4bf4..36b0c5f04 100644
--- a/brainpy/_src/integrators/ode/base.py
+++ b/brainpy/_src/integrators/ode/base.py
@@ -111,7 +111,7 @@ def __call__(self, *args, **kwargs):
kwargs[self.arg_names[i]] = arg
# integral
- new_vars = self.integral(**kwargs)
+ new_vars = self._call_integral(**kwargs)
if len(self.variables) == 1:
dict_vars = {self.variables[0]: new_vars}
else:
diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
index 42ad7f487..d257454ef 100644
--- a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
+++ b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
@@ -94,8 +94,8 @@ def dV(self, V, t, h, n, Iext):
return dVdt
- def update(self, tdi):
- t, dt = tdi.t, tdi.dt
+ def update(self):
+ t, dt = bp.share['t'], bp.share['dt']
V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt)
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py
index 11dd42f58..35557b602 100644
--- a/brainpy/_src/integrators/runner.py
+++ b/brainpy/_src/integrators/runner.py
@@ -9,7 +9,6 @@
import jax.numpy as jnp
import numpy as np
import tqdm.auto
-from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten
from brainpy import math as bm
@@ -245,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):
# progress bar
if self.progress_bar:
- id_tap(lambda *args: self._pbar.update(), ())
+ jax.debug.callback(lambda *args: self._pbar.update(), ())
# return of function monitors
shared = dict(t=t + self.dt, dt=self.dt, i=i)
diff --git a/brainpy/_src/integrators/tests/test_to_math_expr.py b/brainpy/_src/integrators/tests/test_to_math_expr.py
new file mode 100644
index 000000000..aecf83230
--- /dev/null
+++ b/brainpy/_src/integrators/tests/test_to_math_expr.py
@@ -0,0 +1,48 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import brainpy as bp
+
+
+class EINet3(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
+ self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon.desc(size=4000, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.N)
+ self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon.desc(size=4000, tau=10.),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.N)
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(spk[:3200])
+ self.I(spk[3200:])
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+
+def test1():
+ model = EINet3()
+
+ bp.integrators.compile_integrators(model.step_run, 0, 0.)
+ for intg in model.nodes().subset(bp.Integrator).values():
+ print(intg.to_math_expr())
diff --git a/brainpy/_src/losses/comparison.py b/brainpy/_src/losses/comparison.py
index ad0c3ea35..59074eb7b 100644
--- a/brainpy/_src/losses/comparison.py
+++ b/brainpy/_src/losses/comparison.py
@@ -376,7 +376,8 @@ def update(self, input, target):
def nll_loss(input, target, reduction: str = 'mean'):
- r"""The negative log likelihood loss.
+ r"""
+ The negative log likelihood loss.
The negative log likelihood loss. It is useful to train a classification
problem with `C` classes.
diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py
index de559de56..a28ba7d84 100644
--- a/brainpy/_src/math/__init__.py
+++ b/brainpy/_src/math/__init__.py
@@ -44,10 +44,9 @@
from .compat_numpy import *
from .compat_tensorflow import *
from .others import *
-from . import random, linalg, fft, tifunc
+from . import random, linalg, fft
# operators
-from .op_register import *
from .pre_syn_post import *
from . import surrogate, event, sparse, jitconn
diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py
index 0eb391458..746cbe0e9 100644
--- a/brainpy/_src/math/compat_numpy.py
+++ b/brainpy/_src/math/compat_numpy.py
@@ -10,7 +10,6 @@
from .interoperability import *
from .ndarray import Array
-
__all__ = [
'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu',
'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like',
@@ -92,9 +91,8 @@
'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete',
# unique
- 'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'array2string', 'asanyarray',
- 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'disp', 'genfromtxt',
- 'loadtxt', 'info', 'issubclass_', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
+ 'asanyarray', 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'genfromtxt',
+ 'loadtxt', 'info', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat',
]
@@ -204,11 +202,12 @@ def ascontiguousarray(a, dtype=None, order=None):
return asarray(a, dtype=dtype, order=order)
-def asfarray(a, dtype=np.float_):
-  if not np.issubdtype(dtype, np.inexact):
-    dtype = np.float_
+def asfarray(a, dtype=None):
+  if dtype is None or not np.issubdtype(dtype, np.inexact):
+    dtype = np.float64
return asarray(a, dtype=dtype)
+
def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
del assume_unique
ar1_flat = ravel(ar1)
@@ -227,6 +226,7 @@ def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
else:
return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1))
+
# Others
# ------
meshgrid = _compatible_with_brainpy_array(jnp.meshgrid)
@@ -454,7 +454,6 @@ def msort(a):
sometrue = any
-
def shape(a):
"""
Return the shape of an array.
@@ -648,7 +647,6 @@ def size(a, axis=None):
finfo = jnp.finfo
iinfo = jnp.iinfo
-
can_cast = _compatible_with_brainpy_array(jnp.can_cast)
choose = _compatible_with_brainpy_array(jnp.choose)
copy = _compatible_with_brainpy_array(jnp.copy)
@@ -678,23 +676,6 @@ def size(a, axis=None):
# Unique APIs
# -----------
-add_docstring = np.add_docstring
-add_newdoc = np.add_newdoc
-add_newdoc_ufunc = np.add_newdoc_ufunc
-
-
-def array2string(a, max_line_width=None, precision=None,
- suppress_small=None, separator=' ', prefix="",
- style=np._NoValue, formatter=None, threshold=None,
- edgeitems=None, sign=None, floatmode=None, suffix="",
- legacy=None):
- a = as_numpy(a)
- return array2string(a, max_line_width=max_line_width, precision=precision,
- suppress_small=suppress_small, separator=separator, prefix=prefix,
- style=style, formatter=formatter, threshold=threshold,
- edgeitems=edgeitems, sign=sign, floatmode=floatmode, suffix=suffix,
- legacy=legacy)
-
def asscalar(a):
return a.item()
@@ -731,13 +712,9 @@ def common_type(*arrays):
return array_type[0][precision]
-disp = np.disp
-
genfromtxt = lambda *args, **kwargs: asarray(np.genfromtxt(*args, **kwargs))
loadtxt = lambda *args, **kwargs: asarray(np.loadtxt(*args, **kwargs))
-
info = np.info
-issubclass_ = np.issubclass_
def place(arr, mask, vals):
diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py
index eab8b9b66..f82a90ad7 100644
--- a/brainpy/_src/math/defaults.py
+++ b/brainpy/_src/math/defaults.py
@@ -1,13 +1,10 @@
import jax.numpy as jnp
from jax import config
-from brainpy._src.dependency_check import import_taichi
from .modes import NonBatchingMode
from .scales import IdScaling
-__all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_']
-
-ti = import_taichi(error_if_not_found=False)
+__all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'float_', 'complex_']
# Default computation mode.
mode = NonBatchingMode()
@@ -36,16 +33,3 @@
# default return array type
numpy_func_return = 'bp_array' # 'bp_array','jax_array'
-
-
-if ti is not None:
- # Default integer data type in Taichi.
- ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32
-
- # Default float data type in Taichi.
- ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32
-
-else:
- ti_int = None
- ti_float = None
-
diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py
index ebbb8b6a3..984f3137e 100644
--- a/brainpy/_src/math/environment.py
+++ b/brainpy/_src/math/environment.py
@@ -18,9 +18,6 @@
from . import scales
from . import defaults
from .object_transform import naming
-from brainpy._src.dependency_check import import_taichi
-
-ti = import_taichi(error_if_not_found=False)
__all__ = [
# context manage for environment setting
@@ -459,16 +456,10 @@ def set_float(dtype: type):
"""
if dtype in [jnp.float16, 'float16', 'f16']:
defaults.__dict__['float_'] = jnp.float16
- if ti is not None:
- defaults.__dict__['ti_float'] = ti.float16
elif dtype in [jnp.float32, 'float32', 'f32']:
defaults.__dict__['float_'] = jnp.float32
- if ti is not None:
- defaults.__dict__['ti_float'] = ti.float32
elif dtype in [jnp.float64, 'float64', 'f64']:
defaults.__dict__['float_'] = jnp.float64
- if ti is not None:
- defaults.__dict__['ti_float'] = ti.float64
else:
raise NotImplementedError
@@ -494,20 +485,12 @@ def set_int(dtype: type):
"""
if dtype in [jnp.int8, 'int8', 'i8']:
defaults.__dict__['int_'] = jnp.int8
- if ti is not None:
- defaults.__dict__['ti_int'] = ti.int8
elif dtype in [jnp.int16, 'int16', 'i16']:
defaults.__dict__['int_'] = jnp.int16
- if ti is not None:
- defaults.__dict__['ti_int'] = ti.int16
elif dtype in [jnp.int32, 'int32', 'i32']:
defaults.__dict__['int_'] = jnp.int32
- if ti is not None:
- defaults.__dict__['ti_int'] = ti.int32
elif dtype in [jnp.int64, 'int64', 'i64']:
defaults.__dict__['int_'] = jnp.int64
- if ti is not None:
- defaults.__dict__['ti_int'] = ti.int64
else:
raise NotImplementedError
diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py
index 9ebad3e94..91b479b62 100644
--- a/brainpy/_src/math/event/__init__.py
+++ b/brainpy/_src/math/event/__init__.py
@@ -1,2 +1,3 @@
from .csr_matvec import *
+from .csr_matmat import *
diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py
new file mode 100644
index 000000000..0db589ae1
--- /dev/null
+++ b/brainpy/_src/math/event/csr_matmat.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+
+from typing import Union, Tuple
+
+from jax import numpy as jnp
+
+from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
+from brainpy._src.math.ndarray import Array
+
+bti = import_braintaichi(error_if_not_found=False)
+
+__all__ = [
+ 'csrmm',
+]
+
+
+def csrmm(
+ data: Union[float, jnp.ndarray, Array],
+ indices: Union[jnp.ndarray, Array],
+ indptr: Union[jnp.ndarray, Array],
+ matrix: Union[jnp.ndarray, Array],
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+):
+ """Product of CSR sparse matrix and a dense event matrix.
+
+ Args:
+ data : array of shape ``(nse,)``, float.
+ indices : array of shape ``(nse,)``
+ indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+ B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+ dtype ``data.dtype``
+ shape : length-2 tuple representing the matrix shape
+ transpose : boolean specifying whether to transpose the sparse matrix
+ before computing.
+
+ Returns:
+ C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+ representing the matrix-matrix product product.
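+
+  Example (a minimal sketch; requires the optional ``braintaichi`` package)::
+
+    import jax.numpy as jnp
+    import brainpy.math as bm
+
+    # a 2x3 CSR matrix with nonzeros at (0, 1) and (1, 2)
+    data = jnp.array([1.0, 2.0])
+    indices = jnp.array([1, 2])
+    indptr = jnp.array([0, 1, 2])
+    events = jnp.array([[True, False], [False, True], [True, True]])
+    out = bm.event.csrmm(data, indices, indptr, events, shape=(2, 3))
+    # out.shape == (2, 2)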
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+
+ return bti.event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py
index 9890838e7..d9c39370e 100644
--- a/brainpy/_src/math/event/csr_matvec.py
+++ b/brainpy/_src/math/event/csr_matvec.py
@@ -13,23 +13,15 @@
from typing import Union, Tuple
import jax
-import jax.numpy as jnp
-import numpy as np
-from jax.interpreters import ad
-from brainpy._src.dependency_check import import_taichi
-from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import XLACustomOp
-from brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi
-from brainpy._src.math.sparse.utils import csr_to_coo
-from brainpy.errors import PackageMissingError
+from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
+
+bti = import_braintaichi(error_if_not_found=False)
__all__ = [
- 'csrmv'
+ 'csrmv'
]
-ti = import_taichi(error_if_not_found=False)
-
def csrmv(
data: Union[float, jax.Array],
@@ -40,478 +32,37 @@ def csrmv(
shape: Tuple[int, int],
transpose: bool = False,
) -> jax.Array:
- """Product of a sparse CSR matrix and a dense event vector.
-
- This function supports JAX transformations, including `jit()`, `grad()`,
- `vmap()` and `pmap()`.
-
- Parameters
- ----------
- data: ndarray, float
- An array of shape ``(nse,)``.
- indices: ndarray
- An array of shape ``(nse,)``.
- indptr: ndarray
- An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
- events: ndarray
- An array of shape ``(shape[0] if transpose else shape[1],)``
- and dtype ``data.dtype``.
- shape: tuple
- A length-2 tuple representing the matrix shape.
- transpose: bool
- A boolean specifying whether to transpose the sparse matrix
- before computing.
- If ``transpose=True``, the operator will compute based on the
- event-driven property of the ``events`` vector.
-
- Returns
- -------
- y : Array
- The array of shape ``(shape[1] if transpose else shape[0],)`` representing
- the matrix vector product.
- """
- data = as_jax(data)
- indices = as_jax(indices)
- indptr = as_jax(indptr)
- events = as_jax(events)
-
- # checking
- data = jnp.atleast_1d(data)
- if np.ndim(data) == 1:
- if data.shape[0] not in [1, indices.shape[0]]:
- raise ValueError('The size of data should be 1 or be consistent with indices.'
- f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.')
- else:
- raise ValueError('data should be a scalar or 1D vector. '
- f'But we got {np.ndim(data)}-D array.')
- if np.ndim(indices) != 1:
- raise ValueError('indices should be a 1D vector with integer type.')
- if np.ndim(indptr) != 1:
- raise ValueError('indptr should be a 1D vector with integer type.')
- if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]:
- raise ValueError(
- 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.')
- if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]:
- raise ValueError(
- 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.')
- if np.ndim(events) != 1:
- raise ValueError('events should be a 1D vector.')
- if len(shape) != 2:
- raise ValueError('shape should be a length-2 tuple.')
- if transpose:
- if events.shape[0] != shape[0]:
- raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.')
- else:
- if events.shape[0] != shape[1]:
- raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).')
-
- # if the shape of indices is (0,), then we return a zero vector
- if indices.shape[0] == 0:
- return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)
-
- return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]
-
-
-def raw_csrmv_taichi(
- data: Union[float, jax.Array],
- indices: jax.Array,
- indptr: jax.Array,
- events: jax.Array,
- *,
- shape: Tuple[int, int],
- transpose: bool = False
-):
- if ti is None:
- raise PackageMissingError.by_purpose(name='taichi==1.7.0', purpose='customized operators')
-
- if transpose:
- if events.dtype == jnp.bool_:
- if data.shape[0] == 1:
- prim = _event_csrmv_transpose_bool_homo_p
- else:
- prim = _event_csrmv_transpose_bool_heter_p
- else:
- if data.shape[0] == 1:
- prim = _event_csrmv_transpose_homo_p
- else:
- prim = _event_csrmv_transpose_heter_p
- else:
- if events.dtype == jnp.bool_:
- if data.shape[0] == 1:
- prim = _event_csrmv_bool_homo_p
- else:
- prim = _event_csrmv_bool_heter_p
- else:
- if data.shape[0] == 1:
- prim = _event_csrmv_homo_p
- else:
- prim = _event_csrmv_heter_p
-
- # computing
- return prim(data,
- indices,
- indptr,
- events,
- outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)],
- transpose=transpose,
- shape=shape)
-
-
-if ti is not None:
-
- # -------------
- # CPU operators
- # -------------
-
- # 1. The benchmarking shows that the performance of the following transpose
- # kernels is maximized when using serialized mode
- # 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable
- # arguments, we have to define each kernel separately when the
- # non-differentiable/non-jittable arguments are different.
-
- @ti.kernel
- def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- ti.loop_config(serialize=True)
- for row_i in range(indptr.shape[0] - 1):
- if events[row_i]:
- for j in range(indptr[row_i], indptr[row_i + 1]):
- out[indices[j]] += value
-
-
- @ti.kernel
- def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- ti.loop_config(serialize=True)
- for row_i in range(indptr.shape[0] - 1):
- if events[row_i]:
- for j in range(indptr[row_i], indptr[row_i + 1]):
- out[indices[j]] += values[j]
-
-
- @ti.kernel
- def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- ti.loop_config(serialize=True)
- for row_i in range(indptr.shape[0] - 1):
- if events[row_i] != 0.:
- for j in range(indptr[row_i], indptr[row_i + 1]):
- out[indices[j]] += value
-
-
- @ti.kernel
- def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- ti.loop_config(serialize=True)
- for row_i in range(indptr.shape[0] - 1):
- if events[row_i] != 0.:
- for j in range(indptr[row_i], indptr[row_i + 1]):
- out[indices[j]] += values[j]
-
-
- @ti.kernel
- def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- # ti.loop_config(serialize=True)
- for row_i in range(indptr.shape[0] - 1):
- r = 0.
- for j in range(indptr[row_i], indptr[row_i + 1]):
- if events[indices[j]]:
- r += value
- out[row_i] = r
-
-
- @ti.kernel
- def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- # ti.loop_config(serialize=True)
- for row_i in range(indptr.shape[0] - 1):
- r = 0.
- for j in range(indptr[row_i], indptr[row_i + 1]):
- if events[indices[j]]:
- r += values[j]
- out[row_i] = r
-
-
- @ti.kernel
- def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- # ti.loop_config(serialize=True)
- for row_i in range(indptr.shape[0] - 1):
- r = 0.
- for j in range(indptr[row_i], indptr[row_i + 1]):
- if events[indices[j]] != 0.:
- r += value
- out[row_i] = r
-
-
- @ti.kernel
- def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- # ti.loop_config(serialize=True)
- for row_i in range(indptr.shape[0] - 1):
- r = 0.
- for j in range(indptr[row_i], indptr[row_i + 1]):
- if events[indices[j]] != 0.:
- r += values[j]
- out[row_i] = r
-
-
- # -------------
- # GPU operators
- # -------------
-
- # 1. GPU kernels are different from the CPU ones, since the GPU kernels need
- # to use warp-level parallelism to achieve the best performance.
-
- @ti.kernel
- def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- for i in range((indptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- if events[row_i]:
- j = indptr[row_i] + index
- end_index = indptr[row_i + 1]
- while j < end_index:
- out[indices[j]] += value
- j += 32
-
-
- @ti.kernel
- def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- for i in range((indptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- if events[row_i] != 0.:
- j = indptr[row_i] + index
- end_index = indptr[row_i + 1]
- while j < end_index:
- out[indices[j]] += value
- j += 32
-
-
- # TODO
- # It is important to note that the following warp-based kernels
- # should be improved, since the atomic_add for each thread is not
- # very efficient. Instead, the warp-level reduction primitive
- # should be used.
- # see ``warp_reduce_sum()`` function in tifunc.py.
- # However, currently Taichi does not support general warp-level primitives.
-
- @ti.kernel
- def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- for i in range((indptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- r = 0.
- j = indptr[row_i] + index
- end_index = indptr[row_i + 1]
- while j < end_index:
- if events[indices[j]]:
- r += value
- j += 32
- out[row_i] += r # TODO: warp-level primitive
-
-
- @ti.kernel
- def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- for i in range((indptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- r = 0.
- j = indptr[row_i] + index
- end_index = indptr[row_i + 1]
- while j < end_index:
- if events[indices[j]] != 0.:
- r += value
- j += 32
- out[row_i] += r # TODO: warp-level primitive
-
-
- @ti.kernel
- def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- for i in range((indptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- if events[row_i]:
- j = indptr[row_i] + index
- end_index = indptr[row_i + 1]
- while j < end_index:
- out[indices[j]] += values[j]
- j += 32
-
-
- @ti.kernel
- def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- for i in range((indptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- if events[row_i] != 0.:
- j = indptr[row_i] + index
- end_index = indptr[row_i + 1]
- while j < end_index:
- out[indices[j]] += values[j]
- j += 32
-
-
- @ti.kernel
- def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- for i in range((indptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- r = 0.
- j = indptr[row_i] + index
- end_index = indptr[row_i + 1]
- while j < end_index:
- if events[indices[j]]:
- r += values[j]
- j += 32
- out[row_i] += r # TODO: warp-level primitive
-
-
- @ti.kernel
- def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
- indices: ti.types.ndarray(ndim=1),
- indptr: ti.types.ndarray(ndim=1),
- events: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- for i in range((indptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- r = 0.
- j = indptr[row_i] + index
- end_index = indptr[row_i + 1]
- while j < end_index:
- if events[indices[j]] != 0.:
- r += values[j]
- j += 32
- out[row_i] += r # TODO: warp-level primitive
-
-
- def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
- return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose)
-
-
- def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape):
- return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose)
-
-
- def _event_csr_matvec_transpose_taichi(
- ct, values, indices, indptr, events, *, outs, transpose, shape
- ):
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
- raise ValueError("Cannot transpose with respect to sparse indices.")
- if ad.is_undefined_primal(events):
- ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0]
- return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events)
- else:
- if type(ct[0]) is ad.Zero:
- ct_values = ad.Zero(values)
- else:
- if values.aval.shape[0] == 1: # scalar
- ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
- ct_values = jnp.inner(ct[0], ct_values)
- else: # heterogeneous values
- row, col = csr_to_coo(indices, indptr)
- ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
- return ct_values, indices, indptr, events
-
-
- def _define_op(cpu_kernel, gpu_kernel):
- prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
- prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi)
- prim.def_transpose_rule(_event_csr_matvec_transpose_taichi)
- return prim
-
-
- # transpose bool homo
- _event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu,
- _event_csr_matvec_transpose_bool_homo_gpu)
-
- # transpose homo
- _event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu,
- _event_csr_matvec_transpose_homo_gpu)
-
- # not transpose bool homo
- _event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu,
- _event_csr_matvec_bool_homo_gpu)
-
- # not transpose homo
- _event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu,
- _event_csr_matvec_homo_gpu)
-
- # transpose bool heter
- _event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu,
- _event_csr_matvec_transpose_bool_heter_gpu)
-
- # transpose heter
- _event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu,
- _event_csr_matvec_transpose_heter_gpu)
-
- # not transpose bool heter
- _event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu,
- _event_csr_matvec_bool_heter_gpu)
-
- # not transpose heter
- _event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu,
- _event_csr_matvec_heter_gpu)
+ """Product of a sparse CSR matrix and a dense event vector.
+
+ This function supports JAX transformations, including `jit()`, `grad()`,
+ `vmap()` and `pmap()`.
+
+ Parameters
+ ----------
+ data: ndarray, float
+ A homogeneous scalar, or an array of shape ``(1,)`` or ``(nse,)``.
+ indices: ndarray
+ An array of shape ``(nse,)``.
+ indptr: ndarray
+ An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+ events: ndarray
+ An array of shape ``(shape[0] if transpose else shape[1],)``;
+ boolean and floating-point event vectors are both supported.
+ shape: tuple
+ A length-2 tuple representing the matrix shape.
+ transpose: bool
+ Whether to transpose the sparse matrix before computing
+ (``transpose=True`` evaluates ``events @ M`` rather than ``M @ events``).
+ When ``transpose=True``, the operator fully exploits the event-driven
+ property of the ``events`` vector.
+
+ Returns
+ -------
+ y : Array
+ The array of shape ``(shape[1] if transpose else shape[0],)`` representing
+ the matrix-vector product.
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+
+ return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose)
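
A minimal usage sketch of the operator this file now forwards to ``braintaichi`` (the connectivity helper, seed, and sizes are illustrative only; ``braintaichi`` must be installed, otherwise ``raise_braintaichi_not_found()`` is triggered):

```python
import brainpy as bp
import brainpy.math as bm

shape = (100, 200)  # (num_pre, num_post)
# fixed-probability connectivity, as in the (removed) unit tests further below
indices, indptr = bp.conn.FixedProb(0.4, seed=1234)(*shape).require('pre2post')

rng = bm.random.RandomState(1234)
events = rng.random(shape[0]) < 0.1   # boolean spike vector for transpose=True
weight = bm.asarray([1.0])            # homogeneous synaptic weight, shape (1,)

# event-driven product events @ CSR(weight, indices, indptr)
out = bm.event.csrmv(weight, indices, indptr, events, shape=shape, transpose=True)
print(out.shape)  # (200,)
```
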
diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py
deleted file mode 100644
index 3ac1e0ee2..000000000
--- a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py
+++ /dev/null
@@ -1,254 +0,0 @@
-# from jax_taichi import jax_taichi_call
-
-import time
-from functools import partial
-import os
-
-import brainpy as bp
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pandas as pd
-import taichi as ti
-
-bm.set_platform('cpu')
-
-s = [1000, 5000, 10000, 20000, 25000, 30000]
-p = [0.1, 0.2, 0.3, 0.4, 0.5]
-
-shape = [
- 1000,
- 2500,
- 5000,
- 10000,
- 25000,
- 37500,
- 50000
-]
-
-
-
-values_type = [
- 'homo',
- 'heter'
- ]
-events_type = [
- 'bool',
- 'float',
- ]
-transpose = [
- True,
- False
- ]
-
-ITERATION = 100
-if bm.get_platform() == 'cpu':
- ITERATION = 10
-
-print(bm.get_platform())
-
-@partial(jax.jit, static_argnums=(4, 5))
-def event_csrmv_taichi(weight, indices, indptr, vector, shape, transpose):
- r = 0
- for i in range(ITERATION):
- r += bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0]
- return r
-
-@partial(jax.jit, static_argnums=(4, 5))
-def event_csrmv(weight, indices, indptr, vector, shape, transpose):
- r = 0
- for i in range(ITERATION):
- r += bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)
- return r
-
-def test_event_csrmv(shape, values_type, events_type, transpose):
- rng = bm.random.RandomState(seed=1234)
- indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
- vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
- weight = 1.
-
-
- if events_type == 'float':
- vector = vector.astype(bm.float32)
- if values_type == 'heter':
- heter_data = bm.ones(indices.shape) * weight
- weight = heter_data
-
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
-
- time0 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time19 = time.time()
-
-
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
-
- time20 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
- # assert(jnp.allclose(result1[0], result2))
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-PATH = os.path.dirname(os.path.abspath(__file__))
-
-# init dataframe
-df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose',
- 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
- 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
- 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
- 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
-
-### RECTANGULAR MATRIX
-if (bm.get_platform() == 'cpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _values_type in values_type:
- for _events_type in events_type:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
- # append to dataframe
- df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/event_csrmv_cpu.csv', index=False)
-
-if (bm.get_platform() == 'gpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _values_type in values_type:
- for _events_type in events_type:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
- # append to dataframe
- df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/event_csrmv_gpu.csv', index=False)
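
The removed script above unrolls every timed call by hand; the same measurement pattern (warm-up calls to trigger compilation, then wall-clock timing around ``jax.block_until_ready``) can be sketched more compactly. The helper below is only a sketch of that pattern, not part of the library:

```python
import time
import jax

def bench(fn, *args, n_warmup=5, n_runs=10, **kwargs):
    """Return per-run wall-clock times (in ms) of a jitted function."""
    for _ in range(n_warmup):  # trigger compilation and warm the caches
        jax.block_until_ready(fn(*args, **kwargs))
    times = []
    for _ in range(n_runs):
        t0 = time.time()
        jax.block_until_ready(fn(*args, **kwargs))
        times.append((time.time() - t0) * 1000.0)
    return times
```

Used as ``bench(event_csrmv, weight, indices, indptr, vector, shape=shape, transpose=transpose)``, it yields the same per-run timings the deleted script collected by hand.
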
diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py
deleted file mode 100644
index 98793e600..000000000
--- a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# from jax_taichi import jax_taichi_call
-
-import time
-from functools import partial
-import os
-
-import brainpy as bp
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pandas as pd
-import taichi as ti
-
-bm.set_platform('cpu')
-
-s = [1000, 5000, 10000, 20000, 25000, 30000]
-p = [0.1, 0.2, 0.3, 0.4, 0.5]
-
-shape = [
- 1000,
- 2500,
- 5000,
- 10000,
- 25000,
- 37500,
- 50000
-]
-
-
-
-values_type = [
- 'homo',
- 'heter'
- ]
-events_type = [
- 'bool',
- 'float',
- ]
-transpose = [
- True,
- False
- ]
-
-ITERATION = 100
-if bm.get_platform() == 'cpu':
- ITERATION = 10
-
-print(bm.get_platform())
-
-def sum_op(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)
- return r.sum()
-
- return func
-
-
-def sum_op2(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)[0]
- return r.sum()
-
- return func
-
-@partial(jax.jit, static_argnums=(4, 5))
-def event_csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)(
- weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
- return r
-
-@partial(jax.jit, static_argnums=(4, 5))
-def event_csrmv_grad(weight, indices, indptr, vector, shape, transpose):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.event.csrmv), argnums=3)(
- weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
- return r
-
-
-def test_event_csrmv(shape, values_type, events_type, transpose):
- rng = bm.random.RandomState(seed=1234)
- indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
- vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
- weight = 1.
-
-
- if events_type == 'float':
- vector = vector.astype(bm.float32)
- if values_type == 'heter':
- heter_data = bm.ones(indices.shape) * weight
- weight = heter_data
-
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-PATH = os.path.dirname(os.path.abspath(__file__))
-
-# init dataframe
-df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose',
- 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
- 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
- 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
- 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
-
-### RECTANGULAR MATRIX
-if (bm.get_platform() == 'cpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _values_type in values_type:
- for _events_type in events_type:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
- # append to dataframe
- df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/event_csrmv_grad_cpu.csv', index=False)
-
-if (bm.get_platform() == 'gpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _values_type in values_type:
- for _events_type in events_type:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
- # append to dataframe
- df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/event_csrmv_grad_gpu.csv', index=False)
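
The removed gradient benchmark relies on one detail worth noting: ``jax.grad`` needs a scalar output, so the operator is wrapped in a summing closure and differentiated with respect to the float-cast event vector (``argnums=3``). A self-contained sketch of that pattern, with illustrative data:

```python
import jax
import brainpy as bp
import brainpy.math as bm

def sum_op(op):
    # jax.grad requires a scalar loss, so reduce the operator output with a sum
    def func(*args, **kwargs):
        return op(*args, **kwargs).sum()
    return func

shape = (100, 200)
indices, indptr = bp.conn.FixedProb(0.4, seed=1234)(*shape).require('pre2post')
events = bm.random.RandomState(1234).random(shape[0]) < 0.1
weight = bm.asarray([1.0])

# gradient w.r.t. the event vector (argument index 3); events are cast to float for AD
grads = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
    weight, indices, indptr, events.astype(float), shape=shape, transpose=True)
```
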
diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py
deleted file mode 100644
index 181ee5520..000000000
--- a/brainpy/_src/math/event/tests/test_event_csrmv.py
+++ /dev/null
@@ -1,237 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-from functools import partial
-
-import jax
-import pytest
-from absl.testing import parameterized
-
-import brainpy as bp
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
-
-import platform
-force_test = False # turn on to force test on windows locally
-if platform.system() == 'Windows' and not force_test:
- pytest.skip('skip windows', allow_module_level=True)
-
-
-seed = 1234
-
-
-def sum_op(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)
- return r.sum()
-
- return func
-
-
-class Test_event_csr_matvec_taichi(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs)
-
- print()
- bm.set_platform(platform)
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(100, 200), (10, 1000)],
- homo_data=[1.],
- )
- def test_homo(self, transpose, shape, homo_data):
- print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
-
- homo_data = bm.asarray([homo_data])
-
- rng = bm.random.RandomState(seed)
- indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- heter_data = bm.ones(indices.shape) * homo_data
-
- dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- r1 = (events @ dense) if transpose else (dense @ events)
- r2 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
-
- assert (bm.allclose(r1, r2))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(100, 200), (10, 1000)],
- homo_data=[1.],
- )
- def test_homo_vmap(self, shape, transpose, homo_data):
- print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
-
- homo_data = bm.asarray([homo_data])
-
- rng = bm.random.RandomState(seed)
- indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
-
- # vmap 'data'
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events,
- shape=shape, transpose=transpose))
- f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events,
- shape=shape, transpose=transpose))
- vmap_data = bm.as_jax([homo_data] * 10)
- self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
-
- # vmap 'events'
- f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr,
- shape=shape, transpose=transpose))
- f4 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr,
- shape=shape, transpose=transpose))
- vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
- self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))
-
- # vmap 'data' and 'events'
- f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
- f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
-
- vmap_data1 = bm.as_jax([homo_data] * 10)
- vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
- self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
- f6(vmap_data1, vmap_data2)))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(100, 200), (10, 1000)],
- homo_data=[1.],
- )
- def test_homo_grad(self, shape, transpose, homo_data):
- print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
-
- homo_data = bm.asarray([homo_data])
-
- rng = bm.random.RandomState(seed)
- indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)
-
- # grad 'data'
- r1 = jax.grad(sum_op(bm.sparse.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
- r2 = jax.grad(sum_op(bm.event.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r2))
-
- # grad 'events'
- r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape,
- transpose=transpose)
- r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape,
- transpose=transpose)
- self.assertTrue(bm.allclose(r3, r4))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(100, 200), (10, 1000), ]
- )
- def test_heter(self, shape, transpose):
- print(f'test_heter: shape = {shape}, transpose = {transpose}')
- rng = bm.random.RandomState(seed)
- indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- heter_data = bm.as_jax(rng.random(indices.shape))
-
- r1 = bm.sparse.csrmv(heter_data, indices, indptr, events,
- shape=shape, transpose=transpose)
- r2 = bm.event.csrmv(heter_data, indices, indptr, events,
- shape=shape, transpose=transpose)
-
- assert (bm.allclose(r1, r2))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(100, 200), (10, 1000)]
- )
- def test_heter_vmap(self, shape, transpose):
- print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}')
-
- rng = bm.random.RandomState(seed)
- indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
-
- # vmap 'data'
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events,
- shape=shape, transpose=transpose))
- f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events,
- shape=shape, transpose=transpose))
- vmap_data = bm.as_jax(rng.random((10, indices.shape[0])))
- self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
-
- # vmap 'events'
- data = bm.as_jax(rng.random(indices.shape))
- f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr,
- shape=shape, transpose=transpose))
- f4 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr,
- shape=shape, transpose=transpose))
- vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
- self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))
-
- # vmap 'data' and 'events'
- f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee,
- shape=shape, transpose=transpose))
- f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee,
- shape=shape, transpose=transpose))
- vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0])))
- vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
- self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
- f6(vmap_data1, vmap_data2)))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(100, 200), (10, 1000)]
- )
- def test_heter_grad(self, shape, transpose):
- print(f'test_heter_grad: shape = {shape}, transpose = {transpose}')
-
- rng = bm.random.RandomState(seed)
- indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)
-
- # grad 'data'
- data = bm.as_jax(rng.random(indices.shape))
- r1 = jax.grad(sum_op(bm.sparse.csrmv))(
- data, indices, indptr, events, shape=shape, transpose=transpose)
- r2 = jax.grad(sum_op(bm.event.csrmv))(
- data, indices, indptr, events, shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r2))
-
- # grad 'events'
- r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
- data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
- data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r3, r4))
-
- r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))(
- data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- r6 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))(
- data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r5[0], r6[0]))
- self.assertTrue(bm.allclose(r5[1], r6[1]))
-
- bm.clear_buffer_memory()
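
The removed tests validate the event-driven kernel against a dense reference built with ``bm.sparse.csr_to_dense``; the core equivalence they assert can be sketched as follows (same illustrative data as in the earlier sketches):

```python
import brainpy as bp
import brainpy.math as bm

shape = (100, 200)
indices, indptr = bp.conn.FixedProb(0.4, seed=1234)(*shape).require('pre2post')
events = bm.random.RandomState(1234).random(shape[0]) < 0.1
weight = bm.asarray([1.0])

# dense reference: materialize the CSR matrix and use an ordinary matmul
dense = bm.sparse.csr_to_dense(bm.ones(indices.shape) * weight, indices, indptr, shape=shape)
r_dense = events @ dense  # transpose=True convention: vector @ matrix
r_event = bm.event.csrmv(weight, indices, indptr, events, shape=shape, transpose=True)
assert bm.allclose(r_dense, r_event)
```
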
diff --git a/brainpy/_src/math/jitconn/event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py
index a22aac757..e4a33ce0c 100644
--- a/brainpy/_src/math/jitconn/event_matvec.py
+++ b/brainpy/_src/math/jitconn/event_matvec.py
@@ -3,32 +3,18 @@
from typing import Tuple, Optional
import jax
-import numpy as np
-from jax import numpy as jnp
-from brainpy._src.dependency_check import import_taichi
-from brainpy._src.math.interoperability import as_jax
+from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
from brainpy._src.math.jitconn.matvec import (mv_prob_homo,
mv_prob_uniform,
- mv_prob_normal,
- _general_checking,
- raw_mv_prob_homo,
- raw_mv_prob_uniform,
- raw_mv_prob_normal,
- _mv_prob_homo_transpose,
- _mv_prob_uniform_transpose,
- _mv_prob_normal_transpose,
- _reverse)
-from brainpy._src.math.ndarray import _get_dtype
-from brainpy._src.math.op_register import XLACustomOp
-from brainpy.errors import PackageMissingError
+ mv_prob_normal)
-ti = import_taichi(error_if_not_found=False)
+bti = import_braintaichi(error_if_not_found=False)
__all__ = [
- 'event_mv_prob_homo',
- 'event_mv_prob_uniform',
- 'event_mv_prob_normal',
+ 'event_mv_prob_homo',
+ 'event_mv_prob_uniform',
+ 'event_mv_prob_normal',
]
@@ -42,23 +28,12 @@ def event_mv_prob_homo(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- if ti is None:
- raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
-
- events = as_jax(events)
- weight = as_jax(weight)
- if jnp.ndim(weight) < 1:
- weight = jnp.expand_dims(weight, axis=0)
- conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
- conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
- if seed is None:
- with jax.ensure_compile_time_eval():
- seed = np.random.randint(0, int(1e8), 1)
- seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
- return raw_event_mv_prob_homo(events, weight, conn_len, seed,
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel)[0]
+ if bti is None:
+ raise_braintaichi_not_found()
+ return bti.jitc_event_mv_prob_homo(events, weight, conn_prob, seed,
+ shape=shape,
+ transpose=transpose,
+ outdim_parallel=outdim_parallel)
event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__
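
Like the CSR operators above, this module now only forwards to ``braintaichi``. A minimal call sketch of the homogeneous-weight variant, assuming the usual public re-export under ``brainpy.math.jitconn`` and the parameter names used in the forwarding call above (all values are illustrative):

```python
import brainpy.math as bm

events = bm.random.RandomState(1234).random(1000) < 0.1  # pre-synaptic spike vector
out = bm.jitconn.event_mv_prob_homo(
    events,
    weight=1.5,            # homogeneous connection weight
    conn_prob=0.05,        # connection probability of the just-in-time matrix
    seed=1234,             # fixes the deterministically regenerated connectivity
    shape=(1000, 1000),    # size of the virtual connectivity matrix
    transpose=True,
    outdim_parallel=True,  # parallelize over the output dimension
)
```
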
@@ -75,22 +50,10 @@ def event_mv_prob_uniform(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- if ti is None:
- raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
-
- events = as_jax(events)
- if isinstance(w_low, float): w_low = as_jax(w_low)
- if isinstance(w_high, float): w_high = as_jax(w_high)
- w_low = jnp.atleast_1d(as_jax(w_low))
- w_high = jnp.atleast_1d(as_jax(w_high))
- conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
- conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
- if seed is None:
- with jax.ensure_compile_time_eval():
- seed = np.random.randint(0, int(1e8), 1)
- seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
- return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)[0]
+ if bti is None:
+ raise_braintaichi_not_found()
+ return bti.jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__
@@ -107,1054 +70,10 @@ def event_mv_prob_normal(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- if ti is None:
- raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
-
- events = as_jax(events)
- if isinstance(w_mu, float): w_mu = as_jax(w_mu)
- if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma)
- w_mu = jnp.atleast_1d(as_jax(w_mu))
- w_sigma = jnp.atleast_1d(as_jax(w_sigma))
- conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
- conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
- if seed is None:
- with jax.ensure_compile_time_eval():
- seed = np.random.randint(0, int(1e8), 1)
- seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
- return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)[0]
+ if bti is None:
+ raise_braintaichi_not_found()
+ return bti.jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__
-
-if ti is not None:
- from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal)
-
-
- # -------------
- # CPU function
- # -------------
- # For each non-zero event value, it generates a random key using a
- # function lfsr88_key and then uses this key to compute random integers
- # and update the out array based on the computed indices and weight.
- #
- # The function is likely designed to be parallelized.
-
- @ti.kernel
- def _event_mv_prob_homo_bool_cpu(
- events: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- if events[i_col]:
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_row < num_row:
- out[i_row] += weight0
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_homo_outdim_parallel_bool_cpu(
- events: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- if events[i_col]:
- r += weight0
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r
-
-
- # -------------
- # GPU function
- # -------------
- # Contrary to the CPU functions, for each column,
- # this function uses 32 threads (one warp) to
- # parallelize the just-in-time random generation.
-
- @ti.kernel
- def _event_mv_prob_homo_bool_gpu(
- events: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- if events[i_col]:
- index = i & 31
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- out[i_row] += weight0
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_homo_outdim_parallel_bool_gpu(
- events: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.u32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- index = i & 31
- i_col = step * index - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- r += weight0 * events[i_col] # TODO: speed comparison without if else
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += r # TODO: warp-level reduction
-
-
- def _reverse(shape):
- return shape[::-1]
-
-
- # -------------
- # CPU function
- # -------------
- # For each non-zero event value, it generates a random key using a
- # function lfsr88_key and then uses this key to compute random integers
- # and update the out array based on the computed indices and weight.
- #
- # The function is likely designed to be parallelized.
-
- @ti.kernel
- def _event_mv_prob_homo_cpu(
- events: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- if events[i_col] != 0.:
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_row < num_row:
- out[i_row] += weight0
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_homo_outdim_parallel_cpu(
- events: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- if events[i_col] != 0.:
- r += weight0
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r # TODO: warp-level reduction
-
-
- # -------------
- # GPU function
- # -------------
- # Contrary to the CPU functions, for each column,
- # this function uses 32 threads (one warp) to
- # parallelize the just-in-time random generation.
-
- @ti.kernel
- def _event_mv_prob_homo_gpu(
- events: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- if events[i_col] != 0.:
- index = i & 31
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- out[i_row] += weight0
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_homo_outdim_parallel_gpu(
- events: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- index = i & 31
- i_col = step * index - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- r += weight0 * events[i_col] # TODO: speed comparison with if else
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += r # TODO: warp-level reduction
-
-
- def _event_mv_prob_homo_jvp_events(
- evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel
- ):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_homo(evt_dot, weight, clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _event_mv_prob_homo_jvp_weight(
- w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel
- ):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_homo(events, w_dot, clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
- assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64]
- return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights)
-
-
- def raw_event_mv_prob_homo(
- events: jax.Array,
- weight: jax.Array, # vector with size 1
- conn_len: jax.Array, # vector with size 1
- seed: jax.Array, # vector with size 1
- *,
- shape: Tuple[int, int],
- transpose: bool = False,
- outdim_parallel: bool = True,
- ) -> jax.Array:
- mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight)
-
- if outdim_parallel:
- if events.dtype == jnp.bool_:
- prim = _event_mv_prob_homo_outdim_parallel_bool_p
- else:
- prim = _event_mv_prob_homo_outdim_parallel_p
- else:
- if events.dtype == jnp.bool_:
- prim = _event_mv_prob_homo_bool_p
- else:
- prim = _event_mv_prob_homo_p
-
- return prim(events,
- weight,
- conn_len,
- seed,
- outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)],
- shape=mat_shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel)
-
-
- def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel):
- prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
- prim.defjvp(_event_mv_prob_homo_jvp_events,
- _event_mv_prob_homo_jvp_weight,
- None,
- None)
- prim.def_transpose_rule(_mv_prob_homo_transpose)
- return prim
-
-
- # outdim_parallel = True, events.dtype = jnp.bool_
- _event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim(
- cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu,
- gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu
- )
-
- # outdim_parallel = False, events.dtype = jnp.bool_
- _event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim(
- cpu_kernel=_event_mv_prob_homo_bool_cpu,
- gpu_kernel=_event_mv_prob_homo_bool_gpu
- )
-
- # outdim_parallel = True, events.dtype != jnp.bool_
- _event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim(
- cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu,
- gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu
- )
-
- # outdim_parallel = False, events.dtype != jnp.bool_
- _event_mv_prob_homo_p = _define_event_mv_prob_homo_prim(
- cpu_kernel=_event_mv_prob_homo_cpu,
- gpu_kernel=_event_mv_prob_homo_gpu
- )
-
-
- @ti.kernel
- def _event_mv_prob_uniform_bool_cpu(
- events: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- if events[i_col]:
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_row < num_row:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- out[i_row] += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_uniform_outdim_parallel_bool_cpu(
- events: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- if events[i_col]:
- r += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r
-
-
- @ti.kernel
- def _event_mv_prob_uniform_bool_gpu(
- events: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- if events[i_col]:
- index = i & 31
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- out[i_row] += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_uniform_outdim_parallel_bool_gpu(
- events: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.u32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- index = i & 31
- i_col = step * index - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- r += row_v * events[i_col] # TODO: speed comparison without if else
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += r # TODO: warp-level reduction
-
-
- @ti.kernel
- def _event_mv_prob_uniform_cpu(
- events: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- if events[i_col] != 0.:
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_row < num_row:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- out[i_row] += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_uniform_outdim_parallel_cpu(
- events: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- if events[i_col] != 0.:
- r += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r # TODO: warp-level reduction
-
-
- @ti.kernel
- def _event_mv_prob_uniform_gpu(
- events: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- if events[i_col] != 0.:
- index = i & 31
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- out[i_row] += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_uniform_outdim_parallel_gpu(
- events: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- index = i & 31
- i_col = step * index - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- r += row_v * events[i_col] # TODO: speed comparison with if else
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += r # TODO: warp-level reduction
-
-
- def _event_mv_prob_uniform_jvp_events(
- evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel
- ):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _event_mv_prob_uniform_jvp_w_low(
- w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel
- ):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _event_mv_prob_uniform_jvp_w_high(
- w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel
- ):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def raw_event_mv_prob_uniform(
- events: jax.Array,
- w_low: jax.Array, # vector with size 1
- w_high: jax.Array, # vector with size 1
- conn_len: jax.Array, # vector with size 1
- seed: jax.Array, # vector with size 1
- *,
- shape: Tuple[int, int],
- transpose: bool = False,
- outdim_parallel: bool = True,
- ) -> jax.Array:
- mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high)
-
- if outdim_parallel:
- if events.dtype == jnp.bool_:
- prim = _event_mv_prob_uniform_outdim_parallel_bool_p
- else:
- prim = _event_mv_prob_uniform_outdim_parallel_p
- else:
- if events.dtype == jnp.bool_:
- prim = _event_mv_prob_uniform_bool_p
- else:
- prim = _event_mv_prob_uniform_p
-
- return prim(events,
- w_low,
- w_high,
- conn_len,
- seed,
- outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)],
- shape=mat_shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel)
-
-
- def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel):
- prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
- prim.defjvp(_event_mv_prob_uniform_jvp_events,
- _event_mv_prob_uniform_jvp_w_low,
- _event_mv_prob_uniform_jvp_w_high,
- None,
- None)
- prim.def_transpose_rule(_mv_prob_uniform_transpose)
- return prim
-
-
- # outdim_parallel = True, events.dtype = jnp.bool_
- _event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim(
- cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu,
- gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu
- )
-
- # outdim_parallel = False, events.dtype = jnp.bool_
- _event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim(
- cpu_kernel=_event_mv_prob_uniform_bool_cpu,
- gpu_kernel=_event_mv_prob_uniform_bool_gpu
- )
-
- # outdim_parallel = True, events.dtype != jnp.bool_
- _event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim(
- cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu,
- gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu
- )
-
- # outdim_parallel = False, events.dtype != jnp.bool_
- _event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim(
- cpu_kernel=_event_mv_prob_uniform_cpu,
- gpu_kernel=_event_mv_prob_uniform_gpu
- )
-
-
- @ti.kernel
- def _event_mv_prob_normal_bool_cpu(
- events: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- if events[i_col]:
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_row < num_row:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- out[i_row] += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_normal_outdim_parallel_bool_cpu(
- events: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- if events[i_col]:
- r += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r
-
-
- @ti.kernel
- def _event_mv_prob_normal_bool_gpu(
- events: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- if events[i_col]:
- index = i & 31
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- out[i_row] += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_normal_outdim_parallel_bool_gpu(
- events: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.u32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- index = i & 31
- i_col = step * index - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- r += row_v * events[i_col] # TODO: speed comparison without if else
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += r # TODO: warp-level reduction
-
-
- @ti.kernel
- def _event_mv_prob_normal_cpu(
- events: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- if events[i_col] != 0.:
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_row < num_row:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- out[i_row] += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_normal_outdim_parallel_cpu(
- events: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- if events[i_col] != 0.:
- r += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r
-
-
- @ti.kernel
- def _event_mv_prob_normal_gpu(
- events: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- if events[i_col] != 0.:
- index = i & 31
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- out[i_row] += row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _event_mv_prob_normal_outdim_parallel_gpu(
- events: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = events.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- index = i & 31
- i_col = step * index - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- r += row_v * events[i_col] # TODO: speed comparison with if else
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += r # TODO: warp-level reduction
-
-
- def _event_mv_prob_normal_jvp_events(
- evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel
- ):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _event_mv_prob_normal_jvp_w_mu(
- w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel
- ):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _event_mv_prob_normal_jvp_w_sigma(
- w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel
- ):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def raw_event_mv_prob_normal(
- events: jax.Array,
- w_mu: jax.Array, # vector with size 1
- w_sigma: jax.Array, # vector with size 1
- conn_len: jax.Array, # vector with size 1
- seed: jax.Array, # vector with size 1
- *,
- shape: Tuple[int, int],
- transpose: bool = False,
- outdim_parallel: bool = True,
- ) -> jax.Array:
- mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma)
-
- if outdim_parallel:
- if events.dtype == jnp.bool_:
- prim = _event_mv_prob_normal_outdim_parallel_bool_p
- else:
- prim = _event_mv_prob_normal_outdim_parallel_p
- else:
- if events.dtype == jnp.bool_:
- prim = _event_mv_prob_normal_bool_p
- else:
- prim = _event_mv_prob_normal_p
-
- return prim(events,
- w_mu,
- w_sigma,
- conn_len,
- seed,
- outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)],
- shape=mat_shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel)
-
-
- def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel):
- prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
- prim.defjvp(_event_mv_prob_normal_jvp_events,
- _event_mv_prob_normal_jvp_w_mu,
- _event_mv_prob_normal_jvp_w_sigma,
- None,
- None)
- prim.def_transpose_rule(_mv_prob_normal_transpose)
- return prim
-
-
- # outdim_parallel = True, events.dtype = jnp.bool_
- _event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim(
- cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu,
- gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu
- )
-
- # outdim_parallel = False, events.dtype = jnp.bool_
- _event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim(
- cpu_kernel=_event_mv_prob_normal_bool_cpu,
- gpu_kernel=_event_mv_prob_normal_bool_gpu
- )
-
- # outdim_parallel = True, events.dtype != jnp.bool_
- _event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim(
- cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu,
- gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu
- )
-
- # outdim_parallel = False, events.dtype != jnp.bool_
- _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim(
- cpu_kernel=_event_mv_prob_normal_cpu,
- gpu_kernel=_event_mv_prob_normal_gpu
- )
\ No newline at end of file
diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py
index 00e5778f9..4481e6fd6 100644
--- a/brainpy/_src/math/jitconn/matvec.py
+++ b/brainpy/_src/math/jitconn/matvec.py
@@ -1,25 +1,20 @@
# -*- coding: utf-8 -*-
-
-
from typing import Tuple, Optional, Union
import jax
-import numpy as np
-from jax import numpy as jnp
-from jax.interpreters import ad
-from brainpy._src.dependency_check import import_taichi
-from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.ndarray import Array, _get_dtype
-from brainpy._src.math.op_register import XLACustomOp
-from brainpy.errors import PackageMissingError
+from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
+from brainpy._src.math.ndarray import Array
-ti = import_taichi(error_if_not_found=False)
+bti = import_braintaichi(error_if_not_found=False)
__all__ = [
- 'mv_prob_homo',
- 'mv_prob_uniform',
- 'mv_prob_normal',
+ 'mv_prob_homo',
+ 'mv_prob_uniform',
+ 'mv_prob_normal',
+ 'get_homo_weight_matrix',
+ 'get_uniform_weight_matrix',
+ 'get_normal_weight_matrix'
]
@@ -33,70 +28,59 @@ def mv_prob_homo(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- r"""Perform the :math:`y=M@v` operation,
- where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position.
-
- This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations
- on CPU and GPU devices.
-
- .. warning::
-
- This API may change in the future.
-
- In this operation, :math:`M` is the random matrix with a connection probability
- `conn_prob`, and at each connection the value is the same scalar `weight`.
-
- When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
-
- .. note::
-
- Note that the just-in-time generated :math:`M` (`transpose=False`) is
- different from the generated :math:`M^T` (`transpose=True`).
-
- If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time
- matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of
- the speed compared with ``outdim_parallel=False``.
-
- Parameters
- ----------
- vector: Array, ndarray
- The vector.
- weight: float
- The value of the random matrix.
- conn_prob: float
- The connection probability.
- shape: tuple of int
- The matrix shape.
- seed: int
- The random number generation seed.
- transpose: bool
- Transpose the random matrix or not.
- outdim_parallel: bool
- Perform the parallel random generations along the out dimension or not.
- It can be used to set the just-in-time generated :math:M^T: is the same
- as the just-in-time generated :math:`M` when ``transpose=True``.
-
- Returns
- -------
- out: Array, ndarray
- The output of :math:`y = M @ v`.
- """
- if ti is None:
- raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
-
- vector = as_jax(vector)
- if isinstance(weight, float):
- weight = as_jax(weight, dtype=vector.dtype)
- weight = jnp.atleast_1d(as_jax(weight))
- conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
- clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
- if seed is None:
- with jax.ensure_compile_time_eval():
- seed = np.random.randint(0, int(1e8), 1)
- seed = jnp.asarray(seed, dtype=jnp.uint32)
- seed = jnp.atleast_1d(seed)
- return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)[0]
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position.
+
+  This operator supports transformations such as ``jit()``, ``vmap()``, ``grad()`` and ``pmap()``
+  on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+ In this operation, :math:`M` is the random matrix with a connection probability
+ `conn_prob`, and at each connection the value is the same scalar `weight`.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, set ``outdim_parallel=True``, at the cost of speed compared
+    with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ weight: float
+ The value of the random matrix.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to make the just-in-time generated :math:`M^T` the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+
+ return bti.jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
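+
+
+# Illustrative usage sketch (not part of the original change set).  It assumes
+# the public re-export ``brainpy.math.jitconn`` and an installed ``braintaichi``
+# backend; the helper name below is hypothetical.
+def _example_mv_prob_homo():
+  import jax.numpy as jnp
+  import brainpy.math as bm
+
+  v = jnp.ones(200, dtype=jnp.float32)
+  # y = M @ v, where M is a (100, 200) matrix sampled on the fly with
+  # connection probability 0.1 and a homogeneous weight of 1.5.
+  y = bm.jitconn.mv_prob_homo(v, weight=1.5, conn_prob=0.1, seed=42,
+                              shape=(100, 200), transpose=False,
+                              outdim_parallel=True)
+  return y  # shape (100,)
+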
def mv_prob_uniform(
@@ -110,72 +94,61 @@ def mv_prob_uniform(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- r"""Perform the :math:`y=M@v` operation,
- where :math:`M` is just-in-time randomly generated with a uniform distribution for its value.
-
- This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations
- on CPU and GPU devices.
-
- .. warning::
-
- This API may change in the future.
-
- In this operation, :math:`M` is the random matrix with a connection probability
- `conn_prob`, and at each connection the value is the same scalar `weight`.
-
- When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
-
- .. note::
-
- Note that the just-in-time generated :math:`M` (`transpose=False`) is
- different from the generated :math:`M^T` (`transpose=True`).
-
- If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time
- matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of
- the speed compared with ``outdim_parallel=False``.
-
- Parameters
- ----------
- vector: Array, ndarray
- The vector.
- w_low: float
- Lower boundary of the output interval.
- w_high: float
- Upper boundary of the output interval.
- conn_prob: float
- The connection probability.
- shape: tuple of int
- The matrix shape.
- seed: int
- The random number generation seed.
- transpose: bool
- Transpose the random matrix or not.
- outdim_parallel: bool
- Perform the parallel random generations along the out dimension or not.
- It can be used to set the just-in-time generated :math:M^T: is the same
- as the just-in-time generated :math:`M` when ``transpose=True``.
-
- Returns
- -------
- out: Array, ndarray
- The output of :math:`y = M @ v`.
- """
- if ti is None:
- raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
-
- vector = as_jax(vector)
- if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype)
- if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype)
- w_low = jnp.atleast_1d(as_jax(w_low))
- w_high = jnp.atleast_1d(as_jax(w_high))
- conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
- conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
- if seed is None:
- with jax.ensure_compile_time_eval():
- seed = np.random.randint(0, int(1e8), 1)
- seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
- return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)[0]
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a uniform distribution for its value.
+
+  This operator supports transformations such as ``jit()``, ``vmap()``, ``grad()`` and ``pmap()``
+  on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+  In this operation, :math:`M` is the random matrix with a connection probability
+  `conn_prob`, and the value at each connection is drawn uniformly from the
+  interval between `w_low` and `w_high`.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, set ``outdim_parallel=True``, at the cost of speed compared
+    with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ w_low: float
+ Lower boundary of the output interval.
+ w_high: float
+ Upper boundary of the output interval.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to make the just-in-time generated :math:`M^T` the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+
+ return bti.jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
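+
+
+# Illustrative sketch (not part of the original change set): the wrapper is a
+# plain JAX-traceable function, so it can be wrapped in ``jax.jit`` like any
+# other op.  Assumes the public re-export ``brainpy.math.jitconn`` and an
+# installed ``braintaichi`` backend; the helper name is hypothetical.
+def _example_mv_prob_uniform_jit():
+  import jax
+  import jax.numpy as jnp
+  import brainpy.math as bm
+
+  @jax.jit
+  def step(v):
+    # Connection values are drawn uniformly between w_low=0. and w_high=1.
+    return bm.jitconn.mv_prob_uniform(v, w_low=0., w_high=1., conn_prob=0.05,
+                                      seed=2024, shape=(500, 1000),
+                                      transpose=False, outdim_parallel=True)
+
+  return step(jnp.ones(1000))  # shape (500,)
+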
def mv_prob_normal(
@@ -189,732 +162,177 @@ def mv_prob_normal(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- r"""Perform the :math:`y=M@v` operation,
- where :math:`M` is just-in-time randomly generated with a normal distribution for its value.
-
- This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations
- on CPU and GPU devices.
-
- .. warning::
-
- This API may change in the future.
-
- In this operation, :math:`M` is the random matrix with a connection probability
- `conn_prob`, and at each connection the value is the same scalar `weight`.
-
- When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
-
- .. note::
-
- Note that the just-in-time generated :math:`M` (`transpose=False`) is
- different from the generated :math:`M^T` (`transpose=True`).
-
- If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time
- matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of
- the speed compared with ``outdim_parallel=False``.
-
- Parameters
- ----------
- vector: Array, ndarray
- The vector.
- w_mu: float
- Mean (centre) of the distribution.
- w_sigma: float
- Standard deviation (spread or “width”) of the distribution. Must be non-negative.
- conn_prob: float
- The connection probability.
- shape: tuple of int
- The matrix shape.
- seed: int
- The random number generation seed.
- transpose: bool
- Transpose the random matrix or not.
- outdim_parallel: bool
- Perform the parallel random generations along the out dimension or not.
- It can be used to set the just-in-time generated :math:M^T: is the same
- as the just-in-time generated :math:`M` when ``transpose=True``.
-
- Returns
- -------
- out: Array, ndarray
- The output of :math:`y = M @ v`.
- """
- if ti is None:
- raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
-
- vector = as_jax(vector)
- if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype)
- if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype)
- w_mu = jnp.atleast_1d(as_jax(w_mu))
- w_sigma = jnp.atleast_1d(as_jax(w_sigma))
- conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
- conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
- if seed is None:
- with jax.ensure_compile_time_eval():
- seed = np.random.randint(0, int(1e8), 1)
- seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
- return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)[0]
-
-
-def raw_mv_prob_homo(
- vector: jax.Array,
- weight: jax.Array, # vector with size 1
- clen: jax.Array, # vector with size 1
- seed: jax.Array, # vector with size 1
+ r"""Perform the :math:`y=M@v` operation,
+ where :math:`M` is just-in-time randomly generated with a normal distribution for its value.
+
+  This operator supports transformations such as ``jit()``, ``vmap()``, ``grad()`` and ``pmap()``
+  on CPU and GPU devices.
+
+ .. warning::
+
+ This API may change in the future.
+
+  In this operation, :math:`M` is the random matrix with a connection probability
+  `conn_prob`, and the value at each connection is drawn from a normal
+  distribution with mean `w_mu` and standard deviation `w_sigma`.
+
+ When ``transpose=True``, we perform an operation of :math:`y=M^T@v`.
+
+ .. note::
+
+ Note that the just-in-time generated :math:`M` (`transpose=False`) is
+ different from the generated :math:`M^T` (`transpose=True`).
+
+    If you need the same :math:`M` and :math:`M^T` in the just-in-time matrix
+    generation, set ``outdim_parallel=True``, at the cost of speed compared
+    with ``outdim_parallel=False``.
+
+ Parameters
+ ----------
+ vector: Array, ndarray
+ The vector.
+ w_mu: float
+ Mean (centre) of the distribution.
+ w_sigma: float
+ Standard deviation (spread or “width”) of the distribution. Must be non-negative.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to make the just-in-time generated :math:`M^T` the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+ The output of :math:`y = M @ v`.
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+ return bti.jitc_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
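+
+
+# Illustrative sketch (not part of the original change set): with
+# ``transpose=True`` the call computes :math:`y = M^T @ v`, so the vector
+# length must match ``shape[0]``.  Assumes the public re-export
+# ``brainpy.math.jitconn``; the helper name is hypothetical.
+def _example_mv_prob_normal_transpose():
+  import jax.numpy as jnp
+  import brainpy.math as bm
+
+  v = jnp.ones(100)  # length equals shape[0]
+  y = bm.jitconn.mv_prob_normal(v, w_mu=0., w_sigma=0.1, conn_prob=0.1,
+                                seed=7, shape=(100, 200), transpose=True,
+                                outdim_parallel=True)
+  return y  # shape (200,)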
+
+
+def get_homo_weight_matrix(
+ weight: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight)
-
- if outdim_parallel:
- prim = _mv_prob_homo_outdim_parallel_p
- else:
- prim = _mv_prob_homo_p
-
- return prim(vector,
- weight,
- clen,
- seed,
- outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)],
- shape=mat_shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel)
-
-
-def raw_mv_prob_uniform(
- vector: jax.Array,
- w_low: jax.Array,
- w_high: jax.Array,
- conn_len: jax.Array,
- seed: jax.Array,
+  r"""Get the homogeneous weight matrix :math:`M`, in which each connection
+  (generated with probability `conn_prob`) takes the same scalar `weight`.
+
+ Parameters
+ ----------
+  weight: float
+    The homogeneous weight value assigned to each connection.
+  conn_prob: float
+    The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to make the just-in-time generated :math:`M^T` the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+    The homogeneous weight matrix :math:`M`.
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+ return bti.get_homo_weight_matrix(weight, conn_prob, seed, shape=shape, transpose=transpose,
+ outdim_parallel=outdim_parallel)
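+
+
+# Illustrative sketch (not part of the original change set): materializing the
+# matrix allows a dense cross-check of the implicit operator.  Exact agreement
+# is stated here only as an expectation, assuming the backend uses the same
+# sampling scheme for both calls; the helper name and the public re-export
+# ``brainpy.math.jitconn`` are assumptions as well.
+def _example_check_homo_matrix():
+  import jax.numpy as jnp
+  import brainpy.math as bm
+
+  kwargs = dict(conn_prob=0.1, seed=123, shape=(100, 200),
+                transpose=False, outdim_parallel=True)
+  M = bm.jitconn.get_homo_weight_matrix(1.5, **kwargs)   # dense matrix, expected (100, 200)
+  v = jnp.ones(200)
+  y_implicit = bm.jitconn.mv_prob_homo(v, 1.5, **kwargs)
+  y_dense = M @ v
+  return jnp.allclose(y_implicit, y_dense)  # expected True under the assumption above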
+
+
+def get_uniform_weight_matrix(
+ w_low: float,
+ w_high: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high)
-
- if outdim_parallel:
- prim = _mv_prob_uniform_outdim_parallel_p
- else:
- prim = _mv_prob_uniform_p
-
- return prim(vector,
- w_low,
- w_high,
- conn_len,
- seed,
- outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)],
- shape=mat_shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel)
-
-
-def raw_mv_prob_normal(
- vector: jax.Array,
- w_mu: jax.Array,
- w_sigma: jax.Array,
- conn_len: jax.Array,
- seed: jax.Array,
+ r"""Get the weight matrix :math:`M` with a uniform distribution for its value.
+
+ Parameters
+ ----------
+ w_low: float
+ Lower boundary of the output interval.
+ w_high: float
+ Upper boundary of the output interval.
+ conn_prob: float
+ The connection probability.
+ shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to make the just-in-time generated :math:`M^T` the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+ The weight matrix :math:`M`.
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+ return bti.get_uniform_weight_matrix(w_low, w_high, conn_prob, seed, shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
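+
+
+# Illustrative sketch (not part of the original change set): a quick sanity
+# check on the materialized matrix.  Roughly ``conn_prob`` of the entries are
+# expected to be non-zero, and every sampled value should lie between
+# ``w_low`` and ``w_high``.  Assumes ``brainpy.math.jitconn``; the helper name
+# is hypothetical.
+def _example_check_uniform_matrix():
+  import jax.numpy as jnp
+  import brainpy.math as bm
+
+  M = bm.jitconn.get_uniform_weight_matrix(0.5, 1.5, conn_prob=0.2, seed=0,
+                                           shape=(1000, 1000))
+  nonzero = M[M != 0.]
+  density = (M != 0.).mean()  # expected to be close to 0.2
+  in_range = jnp.all((nonzero >= 0.5) & (nonzero <= 1.5))
+  return density, in_range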
+
+
+def get_normal_weight_matrix(
+ w_mu: float,
+ w_sigma: float,
+ conn_prob: float,
+ seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
- mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma)
-
- if outdim_parallel:
- prim = _mv_prob_normal_outdim_parallel_p
- else:
- prim = _mv_prob_normal_p
-
- return prim(vector,
- w_mu,
- w_sigma,
- conn_len,
- seed,
- outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)],
- shape=mat_shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel)
-
-
-def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
- if vector.ndim != 1:
- raise ValueError('vector should be a 1D vector.')
- if len(shape) != 2:
- raise ValueError('shape should be a length-2 tuple.')
- if seed.ndim != 1:
- raise ValueError('seed must be a 1D scalar.')
- if clen.ndim != 1:
- raise ValueError('conn_prob must be a 1D scalar.')
-
- assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64]
- assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64]
-
- for weight in weights:
- if weight.ndim != 1:
- raise ValueError('weight must be a 1D scalar.')
- assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.'
-
- if not isinstance(outdim_parallel, bool):
- raise ValueError('outdim_parallel must be boolean value.')
- if not isinstance(transpose, bool):
- raise ValueError('transpose must be boolean value.')
-
- if transpose:
- out_shape = (shape[1],)
- if vector.shape[0] != shape[0]:
- raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.')
- shape = _reverse(shape)
- else:
- if vector.shape[0] != shape[1]:
- raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).')
- out_shape = (shape[0],)
-
- return shape, out_shape
-
-
-def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
- assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64]
- return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights)
-
-
-def _mv_prob_homo_transpose(
- ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel
-):
- shape = _reverse(shape) if transpose else shape
- if ad.is_undefined_primal(vector):
- if type(ct) is ad.Zero:
- return ad.Zero(vector), weight, clen, seed
- else:
- dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape,
- transpose=not transpose, outdim_parallel=not outdim_parallel)[0]
- return dv, weight, clen, seed
- elif ad.is_undefined_primal(weight):
- if type(ct) is ad.Zero:
- return vector, ad.Zero(weight), clen, seed
- else:
- row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed,
- shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0]
- dw = jnp.sum(row * vector, keepdims=True)
- return vector, dw, clen, seed
- else:
- assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.'
- assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.'
-
-
-def _mv_prob_uniform_transpose(
- ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel
-):
- shape = _reverse(shape) if transpose else shape
- if ad.is_undefined_primal(vector):
- if type(ct) is ad.Zero:
- return ad.Zero(vector), w_low, w_high, clen, seed
- else:
- dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape,
- transpose=not transpose, outdim_parallel=not outdim_parallel)[0]
- return dv, w_low, w_high, clen, seed
- else:
- assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.'
- assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.'
- assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.'
- assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.'
-
-
-def _mv_prob_normal_transpose(
- ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel
-):
- shape = _reverse(shape) if transpose else shape
- if ad.is_undefined_primal(vector):
- if type(ct) is ad.Zero:
- return ad.Zero(vector), w_mu, w_sigma, clen, seed
- else:
- dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape,
- transpose=not transpose, outdim_parallel=not outdim_parallel)[0]
- return dv, w_mu, w_sigma, clen, seed
- else:
- assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.'
- assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.'
- assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.'
- assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.'
-
-
-def _reverse(shape):
- return shape[::-1]
-
-
-if ti is not None:
- from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal)
-
-
- @ti.kernel
- def _mv_prob_homo_cpu(
- vector: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- v = vector[i_col] * weight0
- while i_row < num_row:
- out[i_row] += v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _mv_prob_homo_outdim_parallel_cpu(
- vector: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- r += vector[i_col]
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r * weight0
-
-
- @ti.kernel
- def _mv_prob_homo_gpu(
- vector: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- index = i & 31
- col_v = vector[i_col]
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- out[i_row] += weight0 * col_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _mv_prob_homo_outdim_parallel_gpu(
- vector: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- weight0 = weight[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.u32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- i_thread = i & 31
- i_col = step * i_thread - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- r += vector[i_col]
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += weight0 * r # TODO: warp-level reduction
-
-
- def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose,
- outdim_parallel=outdim_parallel)
-
-
- def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose,
- outdim_parallel=outdim_parallel)
-
-
- def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel):
- prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
- prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None)
- prim.def_transpose_rule(_mv_prob_homo_transpose)
- return prim
-
-
- # outdim_parallel = True
- _mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu,
- gpu_kernel=_mv_prob_homo_outdim_parallel_gpu)
-
- # outdim_parallel = False
- _mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu,
- gpu_kernel=_mv_prob_homo_gpu)
-
-
- @ti.kernel
- def _mv_prob_uniform_cpu(
- vector: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- col_v = vector[i_col]
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_row < num_row:
- key, raw_v = lfsr88_uniform(key, w_min0, w_max0)
- out[i_row] += col_v * raw_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _mv_prob_uniform_outdim_parallel_cpu(
- vector: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- key, raw_v = lfsr88_uniform(key, w_min0, w_max0)
- r += vector[i_col] * raw_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r
-
-
- @ti.kernel
- def _mv_prob_uniform_gpu(
- vector: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- index = i & 31
- col_v = vector[i_col]
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- out[i_row] += row_v * col_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _mv_prob_uniform_outdim_parallel_gpu(
- vector: ti.types.ndarray(ndim=1),
- w_min: ti.types.ndarray(ndim=1),
- w_max: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- w_min0 = w_min[0]
- w_max0 = w_max[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.u32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- i_thread = i & 31
- i_col = step * i_thread - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- key, row_v = lfsr88_uniform(key, w_min0, w_max0)
- r += vector[i_col] * row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += r # TODO: warp-level reduction
-
-
- def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *,
- outs, shape, transpose, outdim_parallel):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *,
- outs, shape, transpose, outdim_parallel):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *,
- outs, shape, transpose, outdim_parallel):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel):
- prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
- prim.defjvp(_mv_prob_uniform_jvp_vector,
- _mv_prob_uniform_jvp_wlow,
- _mv_prob_uniform_jvp_whigh,
- None,
- None)
- prim.def_transpose_rule(_mv_prob_uniform_transpose)
- return prim
-
-
- # outdim_parallel = True
- _mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim(
- cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu,
- gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu
- )
-
- # outdim_parallel = False
- _mv_prob_uniform_p = _define_mv_prob_uniform_prim(
- cpu_kernel=_mv_prob_uniform_cpu,
- gpu_kernel=_mv_prob_uniform_gpu
- )
-
-
- @ti.kernel
- def _mv_prob_normal_cpu(
- vector: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_col in range(num_col):
- col_v = vector[i_col]
- key = lfsr88_key(seed0 + i_col)
- key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_row < num_row:
- key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
- out[i_row] += col_v * raw_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _mv_prob_normal_outdim_parallel_cpu(
- vector: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
-
- for i_row in range(num_row):
- r = 0.
- key = lfsr88_key(seed0 + i_row)
- key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
- while i_col < num_col:
- key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
- r += vector[i_col] * raw_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] = r
-
-
- @ti.kernel
- def _mv_prob_normal_gpu(
- vector: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.uint32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_col * 32):
- i_col = i >> 5
- index = i & 31
- col_v = vector[i_col]
- i_row = step * index - 1
- end = ti.min(i_row + step, num_row)
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
- while i_row < end:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- out[i_row] += row_v * col_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_row += inc
-
-
- @ti.kernel
- def _mv_prob_normal_outdim_parallel_gpu(
- vector: ti.types.ndarray(ndim=1),
- w_mu: ti.types.ndarray(ndim=1),
- w_sigma: ti.types.ndarray(ndim=1),
- clen: ti.types.ndarray(ndim=1),
- seed: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)
- ):
- num_row = out.shape[0]
- num_col = vector.shape[0]
- w_mu0 = w_mu[0]
- w_sigma0 = w_sigma[0]
- clen0 = clen[0]
- seed0 = seed[0]
- step = ti.u32(ti.max((num_row + 1) >> 5, 1))
-
- for i in range(num_row * 32):
- i_row = i >> 5
- i_thread = i & 31
- i_col = step * i_thread - 1
- end_col = ti.min(i_col + step, num_col)
- r = 0.
- key = lfsr88_key(seed0 + i)
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- while i_col < end_col:
- key, row_v = lfsr88_normal(key, w_mu0, w_sigma0)
- r += vector[i_col] * row_v
- key, inc = lfsr88_random_integers(key, 1, clen0)
- i_col += inc
- out[i_row] += r # TODO: warp-level reduction
-
-
- def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel):
- shape = _reverse(shape) if transpose else shape
- return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape,
- transpose=transpose, outdim_parallel=outdim_parallel)
-
-
- def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel):
- prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
- prim.defjvp(_mv_prob_normal_jvp_vector,
- _mv_prob_normal_jvp_w_mu,
- _mv_prob_normal_jvp_w_sigma,
- None,
- None)
- prim.def_transpose_rule(_mv_prob_normal_transpose)
- return prim
-
-
- # outdim_parallel = True
- _mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim(
- cpu_kernel=_mv_prob_normal_outdim_parallel_cpu,
- gpu_kernel=_mv_prob_normal_outdim_parallel_gpu
- )
-
- # outdim_parallel = False
- _mv_prob_normal_p = _define_mv_prob_normal_prim(
- cpu_kernel=_mv_prob_normal_cpu,
- gpu_kernel=_mv_prob_normal_gpu
- )
+ r"""Get the weight matrix :math:`M` with a normal distribution for its value.
+
+ Parameters
+ ----------
+ w_mu: float
+ Mean (centre) of the distribution.
+ w_sigma: float
+ Standard deviation (spread or “width”) of the distribution. Must be non-negative.
+  conn_prob: float
+    The connection probability.
+  shape: tuple of int
+ The matrix shape.
+ seed: int
+ The random number generation seed.
+ transpose: bool
+ Transpose the random matrix or not.
+ outdim_parallel: bool
+ Perform the parallel random generations along the out dimension or not.
+    It can be used to make the just-in-time generated :math:`M^T` the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+ Returns
+ -------
+ out: Array, ndarray
+ The weight matrix :math:`M`.
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+ return bti.get_normal_weight_matrix(w_mu, w_sigma, conn_prob, seed,
+ shape=shape,
+ transpose=transpose, outdim_parallel=outdim_parallel)
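+
+
+# Illustrative sketch (not part of the original change set): with an explicit
+# ``seed`` the sampling should be deterministic, so the same seed is expected
+# to reproduce the same matrix while a different seed is not.  Assumes
+# ``brainpy.math.jitconn``; the helper name is hypothetical.
+def _example_normal_matrix_reproducibility():
+  import jax.numpy as jnp
+  import brainpy.math as bm
+
+  def make(seed):
+    return bm.jitconn.get_normal_weight_matrix(0., 1., conn_prob=0.1, seed=seed,
+                                               shape=(200, 200))
+
+  same = jnp.array_equal(make(42), make(42))  # expected True
+  diff = jnp.array_equal(make(42), make(43))  # expected False
+  return same, diff
+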
diff --git a/brainpy/_src/math/jitconn/tests/event_matvec_jitconn_performance.py b/brainpy/_src/math/jitconn/tests/event_matvec_jitconn_performance.py
deleted file mode 100644
index 2c1ca7110..000000000
--- a/brainpy/_src/math/jitconn/tests/event_matvec_jitconn_performance.py
+++ /dev/null
@@ -1,245 +0,0 @@
-from time import time
-
-from jax import jit
-
-import brainpy as bp
-import brainpy.math as bm
-
-
-def compare_sparse_ops(platform='cpu'):
- """
-
- GPU
- ---
- shape = (1000, 1000), prob = 0.1, transpose = True
- csr sparse 0.09568500518798828 s
- jit conn 0.12936949729919434 s
-
- shape = (1000, 1000), prob = 0.1, transpose = False
- csr sparse 0.09957313537597656 s
- jit conn 0.1456453800201416 s
-
- shape = (1000, 1000), prob = 0.2, transpose = True
- csr sparse 0.1014559268951416 s
- jit conn 0.16193556785583496 s
-
- shape = (1000, 1000), prob = 0.2, transpose = False
- csr sparse 0.10938715934753418 s
- jit conn 0.14464354515075684 s
-
- shape = (1000, 1000), prob = 0.4, transpose = True
- csr sparse 0.14374589920043945 s
- jit conn 0.1551048755645752 s
-
- shape = (1000, 1000), prob = 0.4, transpose = False
- csr sparse 0.14356279373168945 s
- jit conn 0.15198969841003418 s
-
- shape = (1000, 1000), prob = 0.6, transpose = True
- csr sparse 0.1429135799407959 s
- jit conn 0.15459179878234863 s
-
- shape = (1000, 1000), prob = 0.6, transpose = False
- csr sparse 0.14870882034301758 s
- jit conn 0.15899157524108887 s
-
- shape = (1000, 1000), prob = 0.8, transpose = True
- csr sparse 0.1489548683166504 s
- jit conn 0.1636965274810791 s
-
- shape = (1000, 1000), prob = 0.8, transpose = False
- csr sparse 0.09073925018310547 s
- jit conn 0.17296433448791504 s
-
- shape = (1000, 10000), prob = 0.1, transpose = True
- csr sparse 0.14572954177856445 s
- jit conn 0.15570378303527832 s
-
- shape = (1000, 10000), prob = 0.1, transpose = False
- csr sparse 0.14201974868774414 s
- jit conn 0.2694075107574463 s
-
- shape = (1000, 10000), prob = 0.2, transpose = True
- csr sparse 0.1480388641357422 s
- jit conn 0.14784669876098633 s
-
- shape = (1000, 10000), prob = 0.2, transpose = False
- csr sparse 0.14451289176940918 s
- jit conn 0.4144716262817383 s
-
- shape = (1000, 10000), prob = 0.4, transpose = True
- csr sparse 0.14377927780151367 s
- jit conn 0.15256381034851074 s
-
- shape = (1000, 10000), prob = 0.4, transpose = False
- csr sparse 0.1487278938293457 s
- jit conn 0.41004467010498047 s
-
- shape = (1000, 10000), prob = 0.6, transpose = True
- csr sparse 0.1689896583557129 s
- jit conn 0.18367314338684082 s
-
- shape = (1000, 10000), prob = 0.6, transpose = False
- csr sparse 0.15153169631958008 s
- jit conn 0.4159865379333496 s
-
- shape = (1000, 10000), prob = 0.8, transpose = True
- csr sparse 0.15267014503479004 s
- jit conn 0.16814088821411133 s
-
- shape = (1000, 10000), prob = 0.8, transpose = False
- csr sparse 0.1320178508758545 s
- jit conn 0.5114090442657471 s
-
- shape = (10000, 10000), prob = 0.1, transpose = True
- csr sparse 0.15414834022521973 s
- jit conn 0.15847539901733398 s
-
- shape = (10000, 10000), prob = 0.1, transpose = False
- csr sparse 0.1557462215423584 s
- jit conn 0.18897342681884766 s
-
- shape = (10000, 10000), prob = 0.2, transpose = True
- csr sparse 0.28719663619995117 s
- jit conn 0.3945181369781494 s
-
- shape = (10000, 10000), prob = 0.2, transpose = False
- csr sparse 0.29045557975769043 s
- jit conn 0.2662692070007324 s
-
- shape = (10000, 10000), prob = 0.4, transpose = True
- csr sparse 0.26814866065979004 s
- jit conn 0.41262269020080566 s
-
- shape = (10000, 10000), prob = 0.4, transpose = False
- csr sparse 0.14010882377624512 s
- jit conn 0.30821704864501953 s
-
- shape = (10000, 10000), prob = 0.6, transpose = True
- csr sparse 0.34110474586486816 s
- jit conn 0.44765257835388184 s
-
- shape = (10000, 10000), prob = 0.6, transpose = False
- csr sparse 0.14516901969909668 s
- jit conn 0.42423462867736816 s
-
- shape = (10000, 10000), prob = 0.8, transpose = True
- csr sparse 0.38806986808776855 s
- jit conn 0.5052323341369629 s
-
- shape = (10000, 10000), prob = 0.8, transpose = False
- csr sparse 0.13016152381896973 s
- jit conn 0.4791419506072998 s
-
- shape = (50000, 50000), prob = 0.1, transpose = True
- csr sparse 0.1485145092010498 s
- jit conn 0.6013796329498291 s
-
- shape = (50000, 50000), prob = 0.1, transpose = False
- csr sparse 0.2520942687988281 s
- jit conn 0.5886740684509277 s
-
- shape = (50000, 50000), prob = 0.2, transpose = True
- csr sparse 0.41227173805236816 s
- jit conn 1.0801291465759277 s
-
- shape = (50000, 50000), prob = 0.2, transpose = False
- csr sparse 0.5962152481079102 s
- jit conn 1.1053071022033691 s
-
- shape = (50000, 50000), prob = 0.4, transpose = True
- Killed
- """
-
- bm.set_platform(platform)
-
- weight = 1.
- seed = 1234
-
- all_shapes = [
- (int(1e3), int(1e3)),
- (int(1e3), int(1e4)),
- (int(1e4), int(1e4)),
- (int(5e4), int(5e4)),
- (int(5e4), int(1e5)),
- ]
-
- for shape in all_shapes:
- for prob in [0.1, 0.2, 0.4, 0.6, 0.8]:
- indices, indptr = bp.conn.FixedProb(prob, pre=shape[0], post=shape[1]).require('csr')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- for transpose in [True, False]:
- print(f'shape = {shape}, prob = {prob}, transpose = {transpose}')
- f_sparse = jit(lambda e: bm.event.csrmv(weight, indices, indptr, e,
- shape=shape, transpose=transpose))
- f_jitconn = jit(lambda e: bm.jitconn.event_mv_prob_homo(
- e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose))
-
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]).value < prob
- f_sparse(events).block_until_ready()
- f_jitconn(events).block_until_ready()
-
- t0 = time()
- for _ in range(100):
- f_sparse(events).block_until_ready()
- print(f'csr sparse {time() - t0} s')
-
- t0 = time()
- for _ in range(100):
- f_jitconn(events).block_until_ready()
- print(f'jit conn {time() - t0} s')
-
- print()
- bm.clear_buffer_memory()
-
-
-def compare_jitconn_imp(platform='gpu'):
- bm.set_platform(platform)
-
- weight = 1.
- seed = 1234
-
- all_shapes = [
- (int(1e3), int(1e3)),
- (int(1e3), int(1e4)),
- (int(1e4), int(1e4)),
- (int(5e4), int(5e4)),
- (int(5e4), int(1e5)),
- (int(5e5), int(1e5)),
- (int(5e5), int(5e5)),
- ]
-
- for shape in all_shapes:
- for prob in [0.01, 0.05, 0.1, 0.2, 0.4, 0.8]:
- for transpose in [True, False]:
- print(f'shape = {shape}, prob = {prob}, transpose = {transpose}')
- # f1 = jit(lambda e: event_matvec_prob_conn_homo_weight_v1(
- # e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose))
- f2 = jit(lambda e: bm.jitconn.event_mv_prob_homo(
- e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose))
-
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]).value < prob
- # f1(events).block_until_ready()
- f2(events).block_until_ready()
-
- # t0 = time()
- # for _ in range(100):
- # f1(events).block_until_ready()
- # print(f'event_matvec_v1 {time() - t0} s')
-
- t0 = time()
- for _ in range(100):
- f2(events).block_until_ready()
- print(f'event_matvec_v2 {time() - t0} s')
- print()
- bm.clear_buffer_memory()
-
-
-if __name__ == '__main__':
- pass
- # compare_where('cpu')
- # compare_sparse_ops('gpu')
- # compare_jitconn_imp('gpu')
diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py
deleted file mode 100644
index 21a246650..000000000
--- a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py
+++ /dev/null
@@ -1,573 +0,0 @@
-# from jax_taichi import jax_taichi_call
-
-import time
-from functools import partial
-import os
-
-import brainpy as bp
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pandas as pd
-import taichi as ti
-
-bm.set_platform('cpu')
-
-seed = 1234
-
-shape = [
- 1000,
- 2500,
- 5000,
- 10000,
- 25000,
- 37500,
- 50000
- ]
-types = [
- 'homo',
- 'uniform',
- 'normal'
- ]
-transpose = [
- True,
- False
- ]
-outdim_parallel = [
- True,
- False,
- ]
-bool_event = [
- True,
- False
- ]
-conn_prob = 0.05
-homo_data = 1.
-w_low = 0.
-w_high = 1.
-w_mu = 0.
-w_sigma = 0.1
-
-ITERATION = 100
-if bm.get_platform() == 'cpu':
- ITERATION = 10
-
-print(bm.get_platform())
-
-@partial(jax.jit, static_argnums=(4, 5, 6))
-def jitconn_event_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.event_mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0]
- return r
-
-@partial(jax.jit, static_argnums=(4, 5, 6))
-def jitconn_event_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.event_mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0]
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_event_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.event_mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_event_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.event_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_event_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.event_mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_event_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.event_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-
-def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
-
- # groundtruth = bm.as_jax(events, dtype=float) @ bm.as_jax(dense)
-
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-
-def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event):
- print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
- if _type == 'homo':
- return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event)
- elif _type == 'uniform':
- return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event)
- elif _type == 'normal':
- return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event)
- else:
- raise ValueError
-
-
-PATH = os.path.dirname(os.path.abspath(__file__))
-
-# init dataframe
-df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event',
- 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
- 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
- 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
- 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
-
-### RECTANGULAR MATRIX
-if (bm.get_platform() == 'cpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _type in types:
- for _outdim_parallel in outdim_parallel:
- for _transpose in transpose:
- for _bool_event in bool_event:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event)
- # append to dataframe
- df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/jitconn_event_matvec_cpu.csv', index=False)
-
-if (bm.get_platform() == 'gpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _type in types:
- for _outdim_parallel in outdim_parallel:
- for _transpose in transpose:
- for _bool_event in bool_event:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event)
- # append to dataframe
- df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/jitconn_event_matvec_gpu.csv', index=False)
diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py
deleted file mode 100644
index ff4f01afc..000000000
--- a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py
+++ /dev/null
@@ -1,589 +0,0 @@
-# from jax_taichi import jax_taichi_call
-
-import time
-from functools import partial
-import os
-
-import brainpy as bp
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pandas as pd
-import taichi as ti
-
-bm.set_platform('cpu')
-# bm.disable_gpu_memory_preallocation()
-
-seed = 1234
-
-shape = [
- 1000,
- 2500,
- 5000,
- 10000,
- 25000,
- 37500,
- 50000
- ]
-types = [
- 'homo',
- 'uniform',
- 'normal'
- ]
-transpose = [
- True,
- False
- ]
-outdim_parallel = [
- True,
- False,
- ]
-bool_event = [
- True,
- False
- ]
-conn_prob = 0.05
-homo_data = 1.
-w_low = 0.
-w_high = 1.
-w_mu = 0.
-w_sigma = 0.1
-
-print(bm.get_platform())
-
-def sum_op(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)[0]
- return r.sum()
-
- return func
-
-ITERATION = 100
-if bm.get_platform() == 'cpu':
- ITERATION = 10
-
-@partial(jax.jit, static_argnums=(4, 5, 6))
-def jitconn_event_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r +=jax.grad(sum_op(bm.jitconn.event_mv_prob_homo_taichi), argnums=0)(
- vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(4, 5, 6))
-def jitconn_event_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.event_mv_prob_homo), argnums=0)(
- vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_event_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform_taichi), argnums=0)(
- vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_event_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform), argnums=0)(
- vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_event_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal_taichi), argnums=0)(
- vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_event_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal), argnums=0)(
- vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
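- # Five untimed warm-up calls: the first invocation of each jitted function pays
- # compilation cost, which should be excluded from the timed measurements below.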
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event):
- print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
- if _type == 'homo':
- return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event)
- elif _type == 'uniform':
- return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event)
- elif _type == 'normal':
- return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event)
- else:
- raise ValueError
-
-PATH = os.path.dirname(os.path.abspath(__file__))
-
-# init dataframe
-df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event',
- 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
- 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
- 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
- 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
-
-
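-# Sweep every (shape, type, transpose, outdim_parallel, bool_event) combination,
-# append one row of timings per configuration, and write the results to
-# jitconn_event_matvec_grad_cpu.csv or jitconn_event_matvec_grad_gpu.csv.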
-### RECTANGULAR MATRIX
-if (bm.get_platform() == 'cpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _type in types:
- for _outdim_parallel in outdim_parallel:
- for _transpose in transpose:
- for _bool_event in bool_event:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event)
- # append to dataframe
- df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/jitconn_event_matvec_grad_cpu.csv', index=False)
-
-if (bm.get_platform() == 'gpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _type in types:
- for _outdim_parallel in outdim_parallel:
- for _transpose in transpose:
- for _bool_event in bool_event:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event)
- # append to dataframe
- df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/jitconn_event_matvec_grad_gpu.csv', index=False)
diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py
deleted file mode 100644
index 14a19aefb..000000000
--- a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py
+++ /dev/null
@@ -1,560 +0,0 @@
-# from jax_taichi import jax_taichi_call
-
-import time
-from functools import partial
-import os
-
-import brainpy as bp
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pandas as pd
-import taichi as ti
-
-bm.set_platform('gpu')
-
-seed = 1234
-
-shape = [
- 1000,
- 2500,
- 5000,
- 10000,
- 25000,
- 37500,
- 50000
- ]
-types = [
- 'homo',
- 'uniform',
- 'normal'
- ]
-transpose = [
- True,
- False
- ]
-outdim_parallel = [
- True,
- False,
- ]
-bool_event = False
-conn_prob = 0.05
-homo_data = 1.
-w_low = 0.
-w_high = 1.
-w_mu = 0.
-w_sigma = 0.1
-
-ITERATION = 100
-if bm.get_platform() == 'cpu':
- ITERATION = 10
-
-print(bm.get_platform())
-
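-# Each jitted wrapper below accumulates ITERATION operator calls, so a single
-# timed invocation measures the aggregate cost of ITERATION matvec evaluations.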
-@partial(jax.jit, static_argnums=(4, 5, 6))
-def jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-@partial(jax.jit, static_argnums=(4, 5, 6))
-def jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += bm.jitconn.mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
- return r
-
-def test_jitconn_matvec_homo(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
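- # Warm-up: five untimed calls so JIT compilation does not contaminate the timings.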
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-def test_jitconn_matvec_normal(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-def test_jitconn_matvec(shape, _type, transpose, outdim_parallel):
- print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
- if _type == 'homo':
- return test_jitconn_matvec_homo(shape, transpose, outdim_parallel)
- elif _type == 'uniform':
- return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel)
- elif _type == 'normal':
- return test_jitconn_matvec_normal(shape, transpose, outdim_parallel)
- else:
- raise ValueError
-
-PATH = os.path.dirname(os.path.abspath(__file__))
-
-# init dataframe
-df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event',
- 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
- 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
- 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
- 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
-
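-# Sweep all parameter combinations, record the ten timed runs per backend, and
-# dump the results to a per-platform CSV next to this script.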
-### RECTANGULAR MATRIX
-if (bm.get_platform() == 'cpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _type in types:
- for _outdim_parallel in outdim_parallel:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel)
- # append to dataframe
- df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, bool_event,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/jitconn_matvec_cpu.csv', index=False)
-
-if (bm.get_platform() == 'gpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _type in types:
- for _outdim_parallel in outdim_parallel:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel)
- # append to dataframe
- df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, bool_event,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/jitconn_matvec_gpu.csv', index=False)
diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py
deleted file mode 100644
index 165c9b19b..000000000
--- a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py
+++ /dev/null
@@ -1,736 +0,0 @@
-# from jax_taichi import jax_taichi_call
-
-import time
-from functools import partial
-import os
-
-import brainpy as bp
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pandas as pd
-import taichi as ti
-
-bm.set_platform('cpu')
-
-seed = 1234
-
-shape = [
- 1000,
- 2500,
- 5000,
- 10000,
- 25000,
- 37500,
- 50000
- ]
-bool_event = False
-types = [
- 'homo',
- 'uniform',
- 'normal'
- ]
-transpose = [
- True,
- False
- ]
-outdim_parallel = [
- True,
- False,
- ]
-conn_prob = 0.05
-homo_data = 1.
-w_low = 0.
-w_high = 1.
-w_mu = 0.
-w_sigma = 0.1
-
-ITERATION = 100
-if bm.get_platform() == 'cpu':
- ITERATION = 10
-
-print(bm.get_platform())
-
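-# sum_op wraps an operator so its first output is reduced to a scalar, giving
-# jax.grad a scalar loss to differentiate with respect to the input vector.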
-def sum_op(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)[0]
- return r.sum()
-
- return func
-
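-# Each jitted wrapper accumulates ITERATION gradient evaluations, mirroring the
-# forward benchmarks, so one timed call covers ITERATION backward passes.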
-@partial(jax.jit, static_argnums=(4, 5, 6))
-def jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)(
- vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(4, 5, 6))
-def jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)(
- vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)(
- vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)(
- vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)(
- vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-@partial(jax.jit, static_argnums=(5, 6, 7))
-def jitconn_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)(
- vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
- )
- return r
-
-def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
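- # One untimed call to compile the jitted gradient before the five timed runs.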
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- # time.sleep(2)
-
- time0 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
- # time.sleep(2)
-
- time2 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
- # time.sleep(2)
-
- time4 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
- # time.sleep(2)
-
- time6 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-
- time12 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
- # time.sleep(2)
-
- time14 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
- # time.sleep(2)
-
- time16 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
- # time.sleep(2)
-
- time18 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
- time20 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- brainpy_time1 = (time13 - time12) * 1000
- brainpy_time2 = (time15 - time14) * 1000
- brainpy_time3 = (time17 - time16) * 1000
- brainpy_time4 = (time19 - time18) * 1000
- brainpy_time5 = (time21 - time20) * 1000
-
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_2: ', taichi_aot_time2, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_4: ', taichi_aot_time4, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('brainpylib_cpu_1: ', brainpy_time1, 'ms')
- print('brainpylib_cpu_2: ', brainpy_time2, 'ms')
- print('brainpylib_cpu_3: ', brainpy_time3, 'ms')
- print('brainpylib_cpu_4: ', brainpy_time4, 'ms')
- print('brainpylib_cpu_5: ', brainpy_time5, 'ms')
- # assert(jnp.allclose(result1[0], result2))
-
- speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
- (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
-
-def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- # time.sleep(2)
-
- time0 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time1 = time.time()
- # time.sleep(2)
-
- time2 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time3 = time.time()
- # time.sleep(2)
-
- time4 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time5 = time.time()
- # time.sleep(2)
-
- time6 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time7 = time.time()
-
- time8 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time9 = time.time()
-
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
-# print(result1[0])
-# print(result2)
-# print(groundtruth - result1[0])
-# print(groundtruth - result2)
-
- # print(result1[0] - result2)
- # print(bm.allclose(groundtruth, result1[0]))
- # print(bm.allclose(groundtruth, result2))
- # assert bm.allclose(result1[0], result2)
-
- time12 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time13 = time.time()
- # time.sleep(2)
-
- time14 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time15 = time.time()
- # time.sleep(2)
-
- time16 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time17 = time.time()
- # time.sleep(2)
-
- time18 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time19 = time.time()
-
- time20 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time21 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- brainpy_time1 = (time13 - time12) * 1000
- brainpy_time2 = (time15 - time14) * 1000
- brainpy_time3 = (time17 - time16) * 1000
- brainpy_time4 = (time19 - time18) * 1000
- brainpy_time5 = (time21 - time20) * 1000
-
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_2: ', taichi_aot_time2, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_4: ', taichi_aot_time4, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('brainpylib_cpu_1: ', brainpy_time1, 'ms')
- print('brainpylib_cpu_2: ', brainpy_time2, 'ms')
- print('brainpylib_cpu_3: ', brainpy_time3, 'ms')
- print('brainpylib_cpu_4: ', brainpy_time4, 'ms')
- print('brainpylib_cpu_5: ', brainpy_time5, 'ms')
- # assert(jnp.allclose(result1[0], result2))
-
- speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
- (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
-
-def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- # time.sleep(2)
-
- time0 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time1 = time.time()
- # time.sleep(2)
-
- time2 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time3 = time.time()
- # time.sleep(2)
-
- time4 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time5 = time.time()
- # time.sleep(2)
-
- time6 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time7 = time.time()
-
- time8 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time9 = time.time()
-
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
-# print(result1[0])
-# print(result2)
-# print(groundtruth - result1[0])
-# print(groundtruth - result2)
-
- # print(result1[0] - result2)
- # print(bm.allclose(groundtruth, result1[0]))
- # print(bm.allclose(groundtruth, result2))
- # assert bm.allclose(result1[0], result2)
-
- time12 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time13 = time.time()
- # time.sleep(2)
-
- time14 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time15 = time.time()
- # time.sleep(2)
-
- time16 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time17 = time.time()
- # time.sleep(2)
-
- time18 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time19 = time.time()
-
- time20 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time21 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- brainpy_time1 = (time13 - time12) * 1000
- brainpy_time2 = (time15 - time14) * 1000
- brainpy_time3 = (time17 - time16) * 1000
- brainpy_time4 = (time19 - time18) * 1000
- brainpy_time5 = (time21 - time20) * 1000
-
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_2: ', taichi_aot_time2, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_4: ', taichi_aot_time4, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('brainpylib_cpu_1: ', brainpy_time1, 'ms')
- print('brainpylib_cpu_2: ', brainpy_time2, 'ms')
- print('brainpylib_cpu_3: ', brainpy_time3, 'ms')
- print('brainpylib_cpu_4: ', brainpy_time4, 'ms')
- print('brainpylib_cpu_5: ', brainpy_time5, 'ms')
- # assert(jnp.allclose(result1[0], result2))
-
- speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
- (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
-
-def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- # time.sleep(2)
-
- time0 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time1 = time.time()
- # time.sleep(2)
-
- time2 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time3 = time.time()
- # time.sleep(2)
-
- time4 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time5 = time.time()
- # time.sleep(2)
-
- time6 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time9 = time.time()
-
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
-# print(result1[0])
-# print(result2)
-# print(groundtruth - result1[0])
-# print(groundtruth - result2)
-
- # print(result1[0] - result2)
- # print(bm.allclose(groundtruth, result1[0]))
- # print(bm.allclose(groundtruth, result2))
- # assert bm.allclose(result1[0], result2)
-
- time12 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time13 = time.time()
- # time.sleep(2)
-
- time14 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time15 = time.time()
- # time.sleep(2)
-
- time16 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time17 = time.time()
- # time.sleep(2)
-
- time18 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time19 = time.time()
-
- time20 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose))
- time21 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- brainpy_time1 = (time13 - time12) * 1000
- brainpy_time2 = (time15 - time14) * 1000
- brainpy_time3 = (time17 - time16) * 1000
- brainpy_time4 = (time19 - time18) * 1000
- brainpy_time5 = (time21 - time20) * 1000
-
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_2: ', taichi_aot_time2, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_4: ', taichi_aot_time4, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_2: ', brainpy_time2, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_4: ', brainpy_time4, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- # assert(jnp.allclose(result1[0], result2))
-
- speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
- (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
-
-def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- # time.sleep(2)
-
- time0 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time1 = time.time()
- # time.sleep(2)
-
- time2 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time3 = time.time()
- # time.sleep(2)
-
- time4 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time5 = time.time()
- # time.sleep(2)
-
- time6 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time7 = time.time()
-
- time8 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time9 = time.time()
-
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
-# print(result1[0])
-# print(result2)
-# print(groundtruth - result1[0])
-# print(groundtruth - result2)
-
- # print(result1[0] - result2)
- # print(bm.allclose(groundtruth, result1[0]))
- # print(bm.allclose(groundtruth, result2))
- # assert bm.allclose(result1[0], result2)
-
- time12 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time13 = time.time()
- # time.sleep(2)
-
- time14 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time15 = time.time()
- # time.sleep(2)
-
- time16 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time17 = time.time()
- # time.sleep(2)
-
- time18 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time19 = time.time()
-
- time20 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time21 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- brainpy_time1 = (time13 - time12) * 1000
- brainpy_time2 = (time15 - time14) * 1000
- brainpy_time3 = (time17 - time16) * 1000
- brainpy_time4 = (time19 - time18) * 1000
- brainpy_time5 = (time21 - time20) * 1000
-
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_2: ', taichi_aot_time2, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_4: ', taichi_aot_time4, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_2: ', brainpy_time2, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_4: ', brainpy_time4, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- # assert(jnp.allclose(result1[0], result2))
-
- speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
- (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
-
-def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel):
- rng = bm.random.RandomState(seed=seed)
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
-
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- # time.sleep(2)
-
- time0 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time1 = time.time()
- # time.sleep(2)
-
- time2 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time3 = time.time()
- # time.sleep(2)
-
- time4 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time5 = time.time()
- # time.sleep(2)
-
- time6 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time7 = time.time()
-
- time8 = time.time()
- result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time9 = time.time()
-
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
-# print(result1[0])
-# print(result2)
-# print(groundtruth - result1[0])
-# print(groundtruth - result2)
-
- # print(result1[0] - result2)
- # print(bm.allclose(groundtruth, result1[0]))
- # print(bm.allclose(groundtruth, result2))
- # assert bm.allclose(result1[0], result2)
-
- time12 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time13 = time.time()
- # time.sleep(2)
-
- time14 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time15 = time.time()
- # time.sleep(2)
-
- time16 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time17 = time.time()
- # time.sleep(2)
-
- time18 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time19 = time.time()
-
- time20 = time.time()
- result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel))
- time21 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- brainpy_time1 = (time13 - time12) * 1000
- brainpy_time2 = (time15 - time14) * 1000
- brainpy_time3 = (time17 - time16) * 1000
- brainpy_time4 = (time19 - time18) * 1000
- brainpy_time5 = (time21 - time20) * 1000
-
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_2: ', taichi_aot_time2, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_4: ', taichi_aot_time4, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_2: ', brainpy_time2, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_4: ', brainpy_time4, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- # assert(jnp.allclose(result1[0], result2))
-
- speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
- (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
-
-
-def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel):
- print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
- if _type == 'homo':
- return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel)
- elif _type == 'uniform':
- return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel)
- elif _type == 'normal':
- return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel)
- else:
- raise ValueError
-
-
-def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel):
- print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel)
- if _type == 'homo':
- return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel)
- elif _type == 'uniform':
- return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel)
- elif _type == 'normal':
- return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel)
- else:
- raise ValueError
-
-PATH = os.path.dirname(os.path.abspath(__file__))
-
-# init dataframe
-df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel',
- 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
- 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
- 'speedup'])
-
-### RECTANGULAR MATRIX
-if (bm.get_platform() == 'cpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _type in types:
- for _outdim_parallel in outdim_parallel:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel)
- # append to dataframe
- df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup]
- df.to_csv(f'{PATH}/jitconn_matvec_grad_cpu.csv', index=False)
-
-if (bm.get_platform() == 'gpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _type in types:
- for _outdim_parallel in outdim_parallel:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel)
- # append to dataframe
- df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup]
- df.to_csv(f'{PATH}/jitconn_matvec_grad_gpu.csv', index=False)
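
The benchmark removed above repeats the same time()/jax.block_until_ready pattern five times per operator and then derives a speedup column. A minimal, generic sketch of that measurement loop is shown below; the helper name and defaults are illustrative and are not part of BrainPy.

    import time
    import jax

    def bench_ms(fn, *args, n_warmup=1, n_repeat=5, **kwargs):
        # Warm-up calls trigger compilation, so it is excluded from the timings,
        # mirroring the un-timed first call in the removed benchmark.
        for _ in range(n_warmup):
            jax.block_until_ready(fn(*args, **kwargs))
        times = []
        for _ in range(n_repeat):
            t0 = time.time()
            jax.block_until_ready(fn(*args, **kwargs))
            times.append((time.time() - t0) * 1000.0)  # milliseconds
        return times

    # The removed script computes its speedup column as
    #   speedup = sum(brainpy_times) / sum(taichi_times) - 1
    # so positive values mean the taichi-AOT kernel is faster.
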
diff --git a/brainpy/_src/math/jitconn/tests/matmat_jitconn_performance.py b/brainpy/_src/math/jitconn/tests/matmat_jitconn_performance.py
deleted file mode 100644
index e23bd5741..000000000
--- a/brainpy/_src/math/jitconn/tests/matmat_jitconn_performance.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from time import time
-
-import brainpy.math as bm
-import jax.numpy as jnp
-from jax import jit, vmap
-
-
-def compare_jitconn_imp(platform='gpu'):
- bm.set_platform(platform)
-
- seed = 1234
- num_loop = 1
-
- all_shapes = [
- # (int(1e3), int(1e3)),
- # (int(1e3), int(1e4)),
- # (int(1e4), int(1e4)),
- # (int(5e4), int(5e4)),
- # (int(5e4), int(1e5)),
- # (int(5e5), int(1e5)),
- (int(5e5), int(5e5)),
- # (int(1e5), int(1e5)),
- ]
-
- for m in [32, 64, 128, 256]:
- for shape in all_shapes:
- for prob in [0.01]:
- print(f'm = {m}, shape = {shape}, prob = {prob}')
- f1 = jit(
- vmap(lambda a: bm.jitconn.mv_prob_normal(
- a, w_mu=0., w_sigma=0.01, conn_prob=prob, shape=shape, seed=seed, transpose=True
- ))
- )
- f2 = jit(lambda e: bm.jitconn.mm_prob_normal(
- e, w_mu=0., w_sigma=0.01, conn_prob=prob, shape=shape, seed=seed, version='v2'
- ))
-
- rng = bm.random.RandomState()
- mat = bm.as_jax(rng.random((m, shape[0])))
- r1 = f1(mat).block_until_ready()
- r2 = f2(mat).block_until_ready()
- assert r1.shape == r2.shape
- print(jnp.allclose(r1, r2))
-
- t0 = time()
- for _ in range(num_loop):
- f1(mat).block_until_ready()
- print(f'matvec vmap {time() - t0} s')
-
- t0 = time()
- for _ in range(num_loop):
- f2(mat).block_until_ready()
- print(f'matmat {time() - t0} s')
-
- print()
- bm.clear_buffer_memory()
-
-
-if __name__ == '__main__':
- pass
- compare_jitconn_imp('gpu')
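
The removed script above checks a vmapped matrix-vector product against a dedicated matrix-matrix kernel. The equivalence it relies on can be seen with plain JAX arrays; the sketch below uses a dense stand-in matrix rather than BrainPy's sampled connectivity.

    import jax
    import jax.numpy as jnp

    W = jnp.arange(12.0).reshape(3, 4)     # stand-in for the sampled connection matrix
    E = jnp.arange(8.0).reshape(2, 4)      # a batch of two input vectors
    r_vmap = jax.vmap(lambda v: W @ v)(E)  # "matvec + vmap" path
    r_matmat = E @ W.T                     # "matmat" path
    assert jnp.allclose(r_vmap, r_matmat)
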
diff --git a/brainpy/_src/math/jitconn/tests/matmat_testcase.py b/brainpy/_src/math/jitconn/tests/matmat_testcase.py
deleted file mode 100644
index cfd6e5369..000000000
--- a/brainpy/_src/math/jitconn/tests/matmat_testcase.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-from absl.testing import parameterized
-
-shapes = [(100, 200),
- (200, 200),
- (10, 1000),
- (2, 1000),
- (1000, 10),
- (1000, 2)]
-
-
-class Test_matmat_prob_conn(parameterized.TestCase):
- def __init__(self, *args, platform, **kwargs):
- super(Test_matmat_prob_conn, self).__init__(*args, **kwargs)
- bm.set_platform(platform)
- print()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'shape = {shape}, '
- f'm={m}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}'
- f'x64 = {x64}'),
- shape=shape,
- prob=prob,
- w_low=w_low,
- w_high=w_high,
- x64=x64,
- m=m,
- seed=1234
- )
- for x64 in [True, False]
- for shape in shapes
- for prob in [0.01, 0.05, 0.1, 0.4]
- for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
- for m in [5, 8, 15, 33]
- )
- def test_uniform(self, shape, prob, w_low, w_high, m, seed=None, x64=False):
- print(f'test_uniform: '
- f'shape = {shape}, '
- f'm = {m}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'x64 = {x64}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- matrix = bm.as_jax(rng.random((m, shape[0])))
-
- r1 = bm.jitconn.matmat_prob_conn_uniform_weight(matrix,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- version='v1')
- r2 = bm.jitconn.matmat_prob_conn_uniform_weight(matrix,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- version='v1')
- self.assertTrue(jnp.allclose(r1, r2))
-
- f = jax.vmap(lambda a: bm.jitconn.matvec_prob_conn_uniform_weight(
- a, w_low=w_low, w_high=w_high, conn_prob=prob, shape=shape, seed=seed, transpose=True))
- r3 = f(matrix)
- self.assertTrue(jnp.allclose(r1, r3))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(
- testcase_name=(f'test_normal, shape = {shape}, '
- f'm={m}, '
- f'prob={prob}, '
- f'w_mu = {w_mu}, '
- f'w_sigma = {w_sigma},'
- f'x64={x64}'),
- shape=shape,
- prob=prob,
- w_mu=w_mu,
- w_sigma=w_sigma,
- seed=1234,
- m=m,
- )
- for x64 in [True, False]
- for shape in shapes
- for prob in [0.01, 0.05, 0.1, 0.2]
- for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
- for m in [5, 8, 15, 33]
- )
- def test_normal(self, shape, prob, w_mu, w_sigma, m, seed=None, x64=False):
- print(f'_test_normal: '
- f'shape = {shape}, '
- f'm = {m}, '
- f'prob={prob}, '
- f'w_mu = {w_mu}, '
- f'w_sigma = {w_sigma}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- matrix = bm.as_jax(rng.random((m, shape[0])))
-
- r1 = bm.jitconn.matmat_prob_conn_normal_weight(matrix,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed)
- r2 = bm.jitconn.matmat_prob_conn_normal_weight(matrix,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed)
- self.assertTrue(jnp.allclose(r1, r2))
-
- f = jax.vmap(
- lambda a: bm.jitconn.matvec_prob_conn_normal_weight(
- a, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, shape=shape, seed=seed, transpose=True)
- )
- r3 = f(matrix)
- self.assertTrue(jnp.allclose(r1, r3))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
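
The removed test case above builds its parameter grid by passing a generator of dicts to absl's parameterized.named_parameters. A self-contained sketch of that pattern, with a trivial assertion in place of the BrainPy checks, looks like this:

    from absl.testing import absltest, parameterized

    class _Demo(parameterized.TestCase):
        @parameterized.named_parameters(
            dict(testcase_name=f'_prob={prob}', prob=prob) for prob in (0.01, 0.1)
        )
        def test_prob(self, prob):
            # Each generated dict becomes one named test case.
            self.assertGreater(prob, 0.0)

    if __name__ == '__main__':
        absltest.main()
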
diff --git a/brainpy/_src/math/jitconn/tests/matvec_jitconn_performance.py b/brainpy/_src/math/jitconn/tests/matvec_jitconn_performance.py
deleted file mode 100644
index ddeb30c21..000000000
--- a/brainpy/_src/math/jitconn/tests/matvec_jitconn_performance.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from time import time
-
-import brainpy.math as bm
-from jax import jit
-
-
-def compare_jitconn_imp(platform='gpu'):
- bm.set_platform(platform)
-
- weight = 1.
- seed = 1234
-
- all_shapes = [
- # (int(1e3), int(1e3)),
- # (int(1e3), int(1e4)),
- # (int(1e4), int(1e4)),
- # (int(5e4), int(5e4)),
- # (int(5e4), int(1e5)),
- (int(5e5), int(1e5)),
- (int(5e5), int(5e5)),
- ]
-
- for shape in all_shapes:
- for prob in [0.01, 0.05, 0.1, 0.2, 0.4, 0.8]:
- for transpose in [True, False]:
- print(f'shape = {shape}, prob = {prob}, transpose = {transpose}')
- f1 = jit(lambda e: bm.jitconn.mv_prob_homo(
- e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose))
- f2 = jit(lambda e: bm.jitconn.mv_prob_homo(
- e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose))
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
- f1(events).block_until_ready()
- f2(events).block_until_ready()
-
- t0 = time()
- for _ in range(100):
- f1(events).block_until_ready()
- print(f'event_matvec_v1 {time() - t0} s')
-
- t0 = time()
- for _ in range(100):
- f2(events).block_until_ready()
- print(f'event_matvec_v2 {time() - t0} s')
-
- print()
- bm.clear_buffer_memory()
-
-
-if __name__ == '__main__':
- pass
- compare_jitconn_imp('gpu')
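
The removed script above warms both jitted functions once before the timed loops so compilation is not counted, and synchronizes every call with block_until_ready. A generic sketch of that pattern, with a plain JAX reduction standing in for the BrainPy operators, is:

    import time
    import jax
    import jax.numpy as jnp

    f = jax.jit(lambda x: jnp.tanh(x).sum())
    x = jnp.linspace(0.0, 1.0, 1024)
    f(x).block_until_ready()           # first call compiles; keep it out of the timing
    t0 = time.time()
    for _ in range(100):
        f(x).block_until_ready()       # synchronize before starting the next call
    print(f'100 calls: {time.time() - t0:.4f} s')
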
diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py
deleted file mode 100644
index dd1bafded..000000000
--- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py
+++ /dev/null
@@ -1,434 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import jax
-import jax.numpy as jnp
-import pytest
-from absl.testing import parameterized
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
-
-import platform
-force_test = False # turn on to force test on windows locally
-if platform.system() == 'Windows' and not force_test:
- pytest.skip('skip windows', allow_module_level=True)
-
-
-shapes = [(100, 200), (1000, 10)]
-
-
-class Test_event_matvec_prob_conn(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs)
- bm.set_platform(platform)
- print()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- homo_data=[-1.],
- bool_event=[True, False],
- seed=[1234],
- )
- def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False):
- print(f'_test_homo: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
-
- r1 = bm.jitconn.event_mv_prob_homo(events,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = bm.jitconn.event_mv_prob_homo(events,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
-
- # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post')
- # indices = bm.as_jax(indices)
- # indptr = bm.as_jax(indptr)
- # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events,
- # shape=shape, transpose=transpose)
- # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- bool_event=[True, False],
- seed=[1234],
- )
- def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False):
- print(f'_test_homo_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
- weights = bm.as_jax(rng.random(10))
-
- f1 = jax.vmap(
- lambda event, data: bm.jitconn.event_mv_prob_homo(
- event, data, conn_prob=prob, shape=shape, seed=seed,
- transpose=transpose, outdim_parallel=outdim_parallel
- )[0]
- )
- r1 = f1(events, weights)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, weights)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1]
- )
- def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'_test_homo_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.5
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda event, data: bm.jitconn.event_mv_prob_homo(
- event, data, conn_prob=prob, shape=shape, seed=seed,
- outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(),
- argnums=0
- )
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
-
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
-
- r3 = f1(events, 3.)
- r3 = jax.block_until_ready(r3)
-
- self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6))
- self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- w_low=[-1.],
- w_high=[1.],
- bool_event=[True, False]
- )
- def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high,
- bool_event=True, seed=1234, x64=False):
- print(f'_test_uniform: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- r1 = bm.jitconn.event_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = bm.jitconn.event_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- bool_event=[True, False],
- )
- def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob,
- bool_event=True, seed=1234, x64=False):
- print(f'_test_uniform_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- f1 = jax.vmap(
- lambda e: bm.jitconn.event_mv_prob_uniform(e,
- w_low=0.,
- w_high=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- )
-
- r1 = f1(events)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- )
- def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'_test_uniform_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda e, w_high: bm.jitconn.event_mv_prob_uniform(
- e,
- w_low=0.,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose).sum()
- )
-
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6))
- # print(r1)
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1, ],
- w_mu=[0.],
- w_sigma=[0.1],
- bool_event=[True, False],
- )
- def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma,
- bool_event=True, seed=1234, x64=False):
- print(f'_test_normal: shape = {shape}, '
- f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, '
- f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- r1 = bm.jitconn.event_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = bm.jitconn.event_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose = [True, False],
- x64 = [True, False],
- outdim_parallel = [True, False],
- shape = shapes,
- prob = [0.1],
- bool_event = [True, False],
- )
- def test_normal_vmap(self, shape, transpose, outdim_parallel, prob,
- bool_event=True, seed=1234, x64=False):
- print(f'_test_normal_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e,
- w_mu=0.,
- w_sigma=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose))
- r1 = f1(events)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose = [True, False],
- x64 = [True, False],
- outdim_parallel = [True, False],
- shape = shapes,
- prob = [0.1]
- )
- def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'_test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.jit(
- jax.grad(
- lambda e, w_sigma: bm.jitconn.event_mv_prob_normal(
- e,
- w_mu=0.,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose).sum()
- )
- )
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
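
The gradient tests removed above exploit the operators being linear in their weight parameter, so doubling the weight should double the gradient (r1 * 2 ≈ r2). The same check on a toy linear function, independent of BrainPy:

    import jax
    import jax.numpy as jnp

    events = jnp.array([0.0, 1.0, 1.0, 0.0])
    grad_fn = jax.grad(lambda e, w: (e * w).sum(), argnums=0)
    # Gradient w.r.t. the events scales linearly with the weight.
    assert jnp.allclose(grad_fn(events, 2.0), 2.0 * grad_fn(events, 1.0))
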
diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py
deleted file mode 100644
index b2fa77229..000000000
--- a/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py
+++ /dev/null
@@ -1,564 +0,0 @@
-# -*- coding: utf-8 -*-
-from functools import partial
-
-import jax
-import jax.numpy as jnp
-from absl.testing import parameterized
-
-import platform
-import brainpy.math as bm
-
-import pytest
-pytest.skip('Old implementation.', allow_module_level=True)
-is_manual_test = False
-if platform.system() == 'Windows' and not is_manual_test:
- pytest.skip('Under windows, brainpy.math package may need manual tests.', allow_module_level=True)
-
-shapes = [(100, 200),
- # (10, 1000),
- (2, 1000),
- # (1000, 10),
- (1000, 2)]
-
-brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib')
-taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi')
-brainpylib_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='brainpylib')
-taichi_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='taichi')
-brainpylib_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='brainpylib')
-taichi_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='taichi')
-
-class Test_event_matvec_prob_conn(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs)
- bm.set_platform(platform)
- print()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.01, 0.1, 0.5],
- homo_data=[-1., ],
- bool_event=[True, False],
- seed=[1234],
- )
- def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=None, x64=False):
- print(f'_test_homo: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- if not bool_event:
- events = events.astype(float)
-
- r1 = brainpylib_mv_prob_homo(events,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = brainpylib_mv_prob_homo(events,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
-
- r3 = brainpylib_mv_prob_homo(events,
- homo_data,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- r3 = jax.block_until_ready(r3)
- self.assertTrue(jnp.allclose(r1, r3))
-
- # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post')
- # indices = bm.as_jax(indices)
- # indptr = bm.as_jax(indptr)
- # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events,
- # shape=shape, transpose=transpose)
- # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.01, 0.1, 0.5],
- bool_event=[True, False],
- seed=[1234],
- )
- def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=None, x64=False):
- print(f'_test_homo_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
- weights = bm.as_jax(rng.random(10))
-
- f1 = jax.vmap(
- lambda event, data: brainpylib_mv_prob_homo(
- event, data, conn_prob=prob, shape=shape, seed=seed,
- transpose=transpose, outdim_parallel=outdim_parallel
- )
- )
- r1 = f1(events, weights)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, weights)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=f'_test_homo_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}',
- shape=shape, transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob, seed=1234,
- x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1, 0.5]
- )
- def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_homo_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.5
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda event, data: brainpylib_mv_prob_homo(
- event, data, conn_prob=prob, shape=shape, seed=seed,
- outdim_parallel=outdim_parallel, transpose=transpose
- ).sum(),
- argnums=0
- )
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
-
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
-
- r3 = f1(events, 3.)
- r3 = jax.block_until_ready(r3)
-
- self.assertTrue(jnp.allclose(r1 * 3., r3))
- self.assertTrue(jnp.allclose(r1 * 2., r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=f'test_uniform: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}',
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- w_low=w_low,
- w_high=w_high,
- bool_event=bool_event,
- seed=1234,
- x64=x64
- )
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1, 0.4]
- for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
- for bool_event in [True, False]
- )
- def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high,
- bool_event=True, seed=None, x64=False):
- print(f'_test_uniform: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- r1 = brainpylib_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = brainpylib_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
-
- r3 = brainpylib_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- r3 = jax.block_until_ready(r3)
- self.assertTrue(jnp.allclose(r1, r3))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape, transpose=transpose,
- outdim_parallel=outdim_parallel, prob=prob,
- bool_event=bool_event,
- x64=x64,
- seed=1234,
- testcase_name=f'_test_uniform_vmap: '
- f'shape={shape}, '
- f'transpose={transpose}, '
- f'bool_event={bool_event}, '
- f'outdim_parallel={outdim_parallel}, '
- f'prob={prob}, '
- f'x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for bool_event in [True, False]
- )
- def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob,
- bool_event=True, seed=None, x64=False):
- print(f'_test_uniform_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- f1 = jax.vmap(
- lambda e: brainpylib_mv_prob_uniform(e,
- w_low=0.,
- w_high=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- )
-
- r1 = f1(events)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- testcase_name=f'_test_uniform_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_uniform_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda e, w_high: brainpylib_mv_prob_uniform(
- e,
- w_low=0.,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose).sum()
- )
-
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(bm.allclose(r1 * 2., r2))
- # print(r1)
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- w_mu=w_mu,
- w_sigma=w_sigma,
- bool_event=bool_event,
- x64=x64,
- seed=1234,
- testcase_name=f'_test_normal: '
- f'shape={shape}, '
- f'transpose={transpose}, '
- f'outdim_parallel={outdim_parallel}, '
- f'prob={prob}, '
- f'w_mu={w_mu}, '
- f'w_sigma={w_sigma}, '
- f'bool_event={bool_event}, '
- f'x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1, ]
- for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
- for bool_event in [True, False]
- )
- def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma,
- bool_event=True, seed=None, x64=False):
- print(f'_test_normal: shape = {shape}, '
- f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, '
- f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- r1 = brainpylib_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r1 = jax.block_until_ready(r1)
-
- r2 = brainpylib_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
-
- r3 = brainpylib_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- r3 = jax.block_until_ready(r3)
- self.assertTrue(jnp.allclose(r1, r3))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- bool_event=bool_event,
- x64=x64,
- seed=1234,
- testcase_name=f'_test_normal_vmap: '
- f'shape={shape}, '
- f'transpose={transpose}, '
- f'outdim_parallel={outdim_parallel}, '
- f'prob={prob}, '
- f'bool_event={bool_event}, '
- f'x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for bool_event in [True, False]
- )
- def test_normal_vmap(self, shape, transpose, outdim_parallel, prob,
- bool_event=True, seed=None, x64=False):
- print(f'_test_normal_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1
- events = bm.as_jax(events)
- if not bool_event:
- events = events.astype(float)
-
- f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e,
- w_mu=0.,
- w_sigma=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose))
- r1 = f1(events)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(jnp.allclose(r1, r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- x64=x64,
- seed=1234,
- testcase_name=f'_test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}')
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = rng.random(shape[0] if transpose else shape[1]) < 0.1
- events = bm.as_jax(events)
- events = events.astype(float)
-
- f1 = jax.jit(
- jax.grad(
- lambda e, w_sigma: brainpylib_mv_prob_normal(
- e,
- w_mu=0.,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose).sum()
- )
- )
- r1 = f1(events, 1.)
- r1 = jax.block_until_ready(r1)
- r2 = f1(events, 2.)
- r2 = jax.block_until_ready(r2)
- self.assertTrue(bm.allclose(r1 * 2, r2))
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py
deleted file mode 100644
index e42bd3695..000000000
--- a/brainpy/_src/math/jitconn/tests/test_matvec.py
+++ /dev/null
@@ -1,402 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import jax
-import jax.numpy as jnp
-import pytest
-from absl.testing import parameterized
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
-
-import platform
-force_test = False # turn on to force test on windows locally
-if platform.system() == 'Windows' and not force_test:
- pytest.skip('skip windows', allow_module_level=True)
-
-
-shapes = [(100, 200), (1000, 10)]
-
-
-class Test_matvec_prob_conn(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_matvec_prob_conn, self).__init__(*args, **kwargs)
- bm.set_platform(platform)
- print()
-
- @parameterized.product(
- x64=[True, False],
- transpose=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- homo_data=[1.]
- )
- def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False):
- print(f'test_homo: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = bm.jitconn.mv_prob_homo(vector,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
-
- r2 = bm.jitconn.mv_prob_homo(vector,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- )
- def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'test_homo_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
- weights = bm.as_jax(rng.random(10))
-
- f1 = jax.vmap(
- lambda event, data: bm.jitconn.mv_prob_homo(
- event, data,
- conn_prob=prob, shape=shape, seed=seed,
- outdim_parallel=outdim_parallel, transpose=transpose
- )[0]
- )
- r1 = f1(events, weights)
- r2 = f1(events, weights)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- )
- def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'_test_homo_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda event, data: bm.jitconn.mv_prob_homo(
- event, data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- )[0].sum(),
- argnums=0
- )
- r1 = f1(events, 1.)
- r2 = f1(events, 2.)
-
- self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- x64=[True, False],
- transpose=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- w_low=[-0.1],
- w_high=[1.0],
- )
- def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False):
- print(f'test_uniform: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'x64 = {x64}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = bm.jitconn.mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
-
- r2 = bm.jitconn.mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- c = jnp.allclose(r1, r2, atol=1e-6)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- )
- def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'test_uniform_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
-
- f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e,
- w_low=0.,
- w_high=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose))
-
- r1 = f1(events)
- r2 = f1(events)
- self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- x64=[True, False],
- transpose=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- )
- def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'_test_uniform_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- f1 = jax.grad(
- lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform(
- e,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- )[0].sum()
- )
-
- r1 = f1(events, 0., 1.)
- r2 = f1(events, 0., 2.)
-
- self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1],
- w_mu=[0.],
- w_sigma=[0.2]
- )
- def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False):
- print(f'_test_normal: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_mu = {w_mu}, '
- f'w_sigma = {w_sigma}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = bm.jitconn.mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
-
- r2 = bm.jitconn.mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- c = jnp.allclose(r1, r2, atol=1e-6)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1]
- )
- def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'_test_normal_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
-
- f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e,
- w_mu=0.,
- w_sigma=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose))
- r1 = f1(events)
- r2 = f1(events)
- c = jnp.allclose(r1, r2, atol=1e-6)
- if not c:
- print(r1, r2)
- print(r1 - r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- x64=[True, False],
- outdim_parallel=[True, False],
- shape=shapes,
- prob=[0.1]
- )
- def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False):
- print(f'_test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda e, w_sigma: bm.jitconn.mv_prob_normal(
- e,
- w_mu=0.,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- )[0].sum()
- )
- r1 = f1(events, 1.)
- r2 = f1(events, 2.)
- self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_old.py b/brainpy/_src/math/jitconn/tests/test_matvec_old.py
deleted file mode 100644
index 360711e7b..000000000
--- a/brainpy/_src/math/jitconn/tests/test_matvec_old.py
+++ /dev/null
@@ -1,551 +0,0 @@
-# -*- coding: utf-8 -*-
-from functools import partial
-
-import jax
-import jax.numpy as jnp
-from absl.testing import parameterized
-
-import brainpy.math as bm
-import platform
-import pytest
-
-pytest.skip('Old implementation.', allow_module_level=True)
-is_manual_test = False
-if platform.system() == 'Windows' and not is_manual_test:
- pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
-
-shapes = [(100, 200),
- (10, 1000),
- (2, 1000),
- (1000, 10),
- (1000, 2)]
-
-brainpylib_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='brainpylib')
-taichi_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='taichi')
-brainpylib_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='brainpylib')
-taichi_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='taichi')
-brainpylib_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='brainpylib')
-taichi_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='taichi')
-
-class Test_matvec_prob_conn(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_matvec_prob_conn, self).__init__(*args, **kwargs)
- bm.set_platform(platform)
- print()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_homo, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}, '
- f'x64 = {x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- homo_data=homo_data,
- seed=1234)
- for x64 in [True, False]
- for transpose in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for homo_data in [-1., 1.]
- )
- def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=None, x64=False):
- print(f'test_homo: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = brainpylib_mv_prob_homo(vector,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
-
- r2 = brainpylib_mv_prob_homo(vector,
- homo_data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- self.assertTrue(jnp.allclose(r1, r2))
-
- r2 = brainpylib_mv_prob_homo(vector,
- homo_data,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- self.assertTrue(jnp.allclose(r1, r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_homo_vmap, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'test_homo_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
- weights = bm.as_jax(rng.random(10))
-
- f1 = jax.vmap(
- lambda event, data: brainpylib_mv_prob_homo(
- event, data,
- conn_prob=prob, shape=shape, seed=seed,
- outdim_parallel=outdim_parallel, transpose=transpose
- )
- )
- r1 = f1(events, weights)
- r2 = f1(events, weights)
- self.assertTrue(jnp.allclose(r1, r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_homo_grad, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_homo_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda event, data: brainpylib_mv_prob_homo(
- event, data,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- ).sum(),
- argnums=0
- )
- r1 = f1(events, 1.)
- r2 = f1(events, 2.)
-
- self.assertTrue(jnp.allclose(r1 * 2., r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_uniform, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}'
- f'x64 = {x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- w_low=w_low,
- w_high=w_high,
- x64=x64,
- seed=1234)
- for x64 in [True, False]
- for transpose in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)]
- )
- def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None, x64=False):
- print(f'test_uniform: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_low = {w_low}, '
- f'w_high = {w_high}, '
- f'x64 = {x64}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = brainpylib_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
-
- r2 = brainpylib_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- r2 = brainpylib_mv_prob_uniform(events,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=f'test_uniform_vmap, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, x64={x64}',
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'test_uniform_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
-
- f1 = jax.vmap(lambda e: brainpylib_mv_prob_uniform(e,
- w_low=0.,
- w_high=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose))
-
- r1 = f1(events)
- r2 = f1(events)
- self.assertTrue(jnp.allclose(r1, r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=(f'test_uniform_grad, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'x64={x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64)
- for x64 in [True, False]
- for transpose in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_uniform_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- f1 = jax.grad(
- lambda e, w_low, w_high: brainpylib_mv_prob_uniform(
- e,
- w_low=w_low,
- w_high=w_high,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- ).sum()
- )
-
- r1 = f1(events, 0., 1.)
- r2 = f1(events, 0., 2.)
-
- self.assertTrue(bm.allclose(r1 * 2., r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(
- testcase_name=(f'test_normal, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_mu = {w_mu}, '
- f'w_sigma = {w_sigma},'
- f'x64={x64}'),
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- w_mu=w_mu,
- w_sigma=w_sigma,
- seed=1234
- )
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)]
- )
- def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=None, x64=False):
- print(f'_test_normal: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'w_mu = {w_mu}, '
- f'w_sigma = {w_sigma}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
-
- r1 = brainpylib_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
-
- r2 = brainpylib_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- r2 = brainpylib_mv_prob_normal(events,
- w_mu=w_mu,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=(shape[1], shape[0]),
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=not transpose)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(testcase_name=f'test_normal_vmap, shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'x64={x64}',
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_normal_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
-
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1])))
-
- f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e,
- w_mu=0.,
- w_sigma=1.,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose))
- r1 = f1(events)
- r2 = f1(events)
- c = jnp.allclose(r1, r2)
- if not c:
- print(r1, r2)
- self.assertTrue(c)
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
-
- @parameterized.named_parameters(
- dict(shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- seed=1234,
- x64=x64,
- testcase_name=f'test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'x64={x64}')
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1]
- )
- def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False):
- print(f'_test_normal_grad: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}')
-
- if x64:
- bm.enable_x64()
- rng = bm.random.RandomState()
- events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
- events = events.astype(float)
-
- f1 = jax.grad(
- lambda e, w_sigma: brainpylib_mv_prob_normal(
- e,
- w_mu=0.,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose
- ).sum()
- )
- r1 = f1(events, 1.)
- r2 = f1(events, 2.)
- print('r1:', r1)
- print('r2:', r2)
- self.assertTrue(bm.allclose(r1 * 2., r2))
-
- if x64:
- bm.disable_x64()
- bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/modes.py b/brainpy/_src/math/modes.py
index d46afc248..771d63ea7 100644
--- a/brainpy/_src/math/modes.py
+++ b/brainpy/_src/math/modes.py
@@ -20,7 +20,8 @@ def __repr__(self):
return self.__class__.__name__
def __eq__(self, other: 'Mode'):
- assert isinstance(other, Mode)
+ if not isinstance(other, Mode):
+ return False
return other.__class__ == self.__class__
def is_one_of(self, *modes):
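A note on the `__eq__` change above: comparing a `Mode` with a non-`Mode` object now returns `False` instead of failing an assertion, while equality between modes still requires the exact same class. A minimal sketch of the new behaviour, assuming the public `brainpy.math` mode classes:

    import brainpy.math as bm

    assert bm.TrainingMode() == bm.TrainingMode()         # same class -> True
    assert not (bm.TrainingMode() == bm.BatchingMode())   # a subclass relation is not enough -> False
    assert not (bm.TrainingMode() == 'training')          # non-Mode operand -> False, no AssertionError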
diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py
index 791c8d9fe..b435415d6 100644
--- a/brainpy/_src/math/ndarray.py
+++ b/brainpy/_src/math/ndarray.py
@@ -660,7 +660,7 @@ def searchsorted(self, v, side='left', sorter=None):
"""
return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter))
- def sort(self, axis=-1, kind='quicksort', order=None):
+ def sort(self, axis=-1, stable=True, order=None):
"""Sort an array in-place.
Parameters
@@ -668,11 +668,8 @@ def sort(self, axis=-1, kind='quicksort', order=None):
axis : int, optional
Axis along which to sort. Default is -1, which means sort along the
last axis.
- kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}
- Sorting algorithm. The default is 'quicksort'. Note that both 'stable'
- and 'mergesort' use timsort under the covers and, in general, the
- actual implementation will vary with datatype. The 'mergesort' option
- is retained for backwards compatibility.
+ stable : bool, optional
+ Whether to use a stable sorting algorithm. The default is True.
order : str or list of str, optional
When `a` is an array with fields defined, this argument specifies
which fields to compare first, second, etc. A single field can
@@ -680,7 +677,8 @@ def sort(self, axis=-1, kind='quicksort', order=None):
but unspecified fields will still be used, in the order in which
they come up in the dtype, to break ties.
"""
- self.value = self.value.sort(axis=axis, kind=kind, order=order)
+ self.value = self.value.sort(axis=axis, stable=stable, order=order)
+
def squeeze(self, axis=None):
"""Remove axes of length one from ``a``."""
diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py
index 2e5e103cc..59509c0c7 100644
--- a/brainpy/_src/math/object_transform/autograd.py
+++ b/brainpy/_src/math/object_transform/autograd.py
@@ -94,7 +94,6 @@ def __init__(
self.target = target
# transform
- self._eval_dyn_vars = False
self._grad_transform = transform
self._dyn_vars = VariableStack()
self._transform = None
@@ -198,20 +197,18 @@ def __call__(self, *args, **kwargs):
)
return self._return(rets)
- elif not self._eval_dyn_vars: # evaluate dynamical variables
- stack = get_stack_cache(self.target)
- if stack is None:
- with VariableStack() as stack:
- rets = eval_shape(self._transform,
- [v.value for v in self._grad_vars], # variables for gradients
- {}, # dynamical variables
- *args,
- **kwargs)
- cache_stack(self.target, stack)
-
+ # evaluate dynamical variables
+ stack = get_stack_cache(self.target)
+ if stack is None:
+ with VariableStack() as stack:
+ rets = eval_shape(self._transform,
+ [v.value for v in self._grad_vars], # variables for gradients
+ {}, # dynamical variables
+ *args,
+ **kwargs)
+ cache_stack(self.target, stack)
self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
- self._eval_dyn_vars = True
# if not the outermost transformation
if not stack.is_first_stack():
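With the `_eval_dyn_vars` flag removed, the gradient transform always resolves its `VariableStack` from the cache (or rebuilds it via `eval_shape`), so it composes with `jit` without a separate evaluation pass. A minimal sketch of the intended usage, assuming the standard `brainpy.math` gradient API:

    import brainpy.math as bm

    w = bm.Variable(bm.ones(3))

    def loss(x):
        return bm.sum(w * x)

    grad_fn = bm.jit(bm.grad(loss, grad_vars=w))
    print(grad_fn(bm.arange(3.)))  # gradient w.r.t. the Variable `w`, i.e. [0., 1., 2.]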
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index 3edeb08e8..ff3023339 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -7,7 +7,6 @@
import jax
import jax.numpy as jnp
from jax.errors import UnexpectedTracerError
-from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten, tree_unflatten
from tqdm.auto import tqdm
@@ -421,14 +420,14 @@ def call(pred, x=None):
def _warp(f):
@functools.wraps(f)
def new_f(*args, **kwargs):
- return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
+ return jax.tree.map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
return new_f
def _warp_data(data):
def new_f(*args, **kwargs):
- return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
+ return jax.tree.map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
return new_f
@@ -727,7 +726,7 @@ def fun2scan(carry, x):
dyn_vars[k]._value = carry[k]
results = body_fun(*x, **unroll_kwargs)
if progress_bar:
- id_tap(lambda *arg: bar.update(), ())
+ jax.debug.callback(lambda *args: bar.update(), ())
return dyn_vars.dict_data(), results
if remat:
@@ -916,15 +915,15 @@ def fun2scan(carry, x):
dyn_vars[k]._value = dyn_vars_data[k]
carry, results = body_fun(carry, x)
if progress_bar:
- id_tap(lambda *arg: bar.update(), ())
- carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
+ jax.debug.callback(lambda *arg: bar.update(), ())
+ carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
return (dyn_vars.dict_data(), carry), results
if remat:
fun2scan = jax.checkpoint(fun2scan)
def call(init, operands):
- init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
+ init = jax.tree.map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
return jax.lax.scan(f=fun2scan,
init=(dyn_vars.dict_data(), init),
xs=operands,
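The hunks above migrate off the removed `jax.experimental.host_callback.id_tap` and the deprecated `jax.tree_map` alias, switching to `jax.debug.callback` and `jax.tree.map`. A small standalone sketch of the same progress-bar pattern inside a scan (the body function is purely illustrative):

    import jax
    import jax.numpy as jnp
    from tqdm.auto import tqdm

    bar = tqdm(total=100)

    def body(carry, x):
        jax.debug.callback(lambda: bar.update())  # host-side side effect from traced code
        return carry + x, carry

    carry, ys = jax.lax.scan(body, 0.0, jnp.arange(100.0))
    bar.close()

    doubled = jax.tree.map(lambda a: a * 2, {'w': jnp.ones(3)})  # jax.tree.map replaces jax.tree_map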
diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py
index 551a0949c..764ce1ee5 100644
--- a/brainpy/_src/math/object_transform/jit.py
+++ b/brainpy/_src/math/object_transform/jit.py
@@ -491,9 +491,8 @@ def call_fun(self, *args, **kwargs):
return call_fun
-
def _make_transform(fun, stack):
- @wraps(fun)
+ # @wraps(fun)
def _transform_function(variable_data: Dict, *args, **kwargs):
for key, v in stack.items():
v._value = variable_data[key]
diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py
index 1181e003b..717b9af8a 100644
--- a/brainpy/_src/math/object_transform/naming.py
+++ b/brainpy/_src/math/object_transform/naming.py
@@ -25,7 +25,7 @@ def check_name_uniqueness(name, obj):
f'In BrainPy, each object should have a unique name. '
f'However, we detect that {obj} has a used name "{name}". \n'
f'If you try to run multiple trials, you may need \n\n'
- f'>>> brainpy.brainpy_object.clear_name_cache() \n\n'
+ f'>>> brainpy.math.clear_name_cache() \n\n'
f'to clear all cached names. '
)
else:
diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py
index 90829d80e..bb4adf1d0 100644
--- a/brainpy/_src/math/object_transform/tests/test_autograd.py
+++ b/brainpy/_src/math/object_transform/tests/test_autograd.py
@@ -86,6 +86,17 @@ def call(a, b, c):
assert aux[1] == bm.exp(0.1)
+ def test_grad_jit(self):
+ def call(a, b, c): return bm.sum(a + b + c)
+
+ bm.random.seed(1)
+ a = bm.ones(10)
+ b = bm.random.randn(10)
+ c = bm.random.uniform(size=10)
+ f_grad = bm.jit(bm.grad(call))
+ assert (f_grad(a, b, c) == 1.).all()
+
+
class TestObjectFuncGrad(unittest.TestCase):
def test_grad_ob1(self):
class Test(bp.BrainPyObject):
@@ -1172,52 +1183,52 @@ def f(a, b):
-class TestHessian(unittest.TestCase):
- def test_hessian5(self):
- bm.set_mode(bm.training_mode)
-
- class RNN(bp.DynamicalSystem):
- def __init__(self, num_in, num_hidden):
- super(RNN, self).__init__()
- self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
- self.out = bp.dnn.Dense(num_hidden, 1)
-
- def update(self, x):
- return self.out(self.rnn(x))
-
- # define the loss function
- def lossfunc(inputs, targets):
- runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
- predicts = runner.predict(inputs)
- loss = bp.losses.mean_squared_error(predicts, targets)
- return loss
-
- model = RNN(1, 2)
- data_x = bm.random.rand(1, 1000, 1)
- data_y = data_x + bm.random.randn(1, 1000, 1)
-
- bp.reset_state(model, 1)
- losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
- hess_matrix = losshess(data_x, data_y)
-
- weights = model.train_vars().unique()
-
- # define the loss function
- def loss_func_for_jax(weight_vals, inputs, targets):
- for k, v in weight_vals.items():
- weights[k].value = v
- runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
- predicts = runner.predict(inputs)
- loss = bp.losses.mean_squared_error(predicts, targets)
- return loss
-
- bp.reset_state(model, 1)
- jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
-
- for k, v in hess_matrix.items():
- for kk, vv in v.items():
- self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
-
- bm.clear_buffer_memory()
+# class TestHessian(unittest.TestCase):
+# def test_hessian5(self):
+# bm.set_mode(bm.training_mode)
+#
+# class RNN(bp.DynamicalSystem):
+# def __init__(self, num_in, num_hidden):
+# super(RNN, self).__init__()
+# self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
+# self.out = bp.dnn.Dense(num_hidden, 1)
+#
+# def update(self, x):
+# return self.out(self.rnn(x))
+#
+# # define the loss function
+# def lossfunc(inputs, targets):
+# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
+# predicts = runner.predict(inputs)
+# loss = bp.losses.mean_squared_error(predicts, targets)
+# return loss
+#
+# model = RNN(1, 2)
+# data_x = bm.random.rand(1, 1000, 1)
+# data_y = data_x + bm.random.randn(1, 1000, 1)
+#
+# bp.reset_state(model, 1)
+# losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
+# hess_matrix = losshess(data_x, data_y)
+#
+# weights = model.train_vars().unique()
+#
+# # define the loss function
+# def loss_func_for_jax(weight_vals, inputs, targets):
+# for k, v in weight_vals.items():
+# weights[k].value = v
+# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
+# predicts = runner.predict(inputs)
+# loss = bp.losses.mean_squared_error(predicts, targets)
+# return loss
+#
+# bp.reset_state(model, 1)
+# jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
+#
+# for k, v in hess_matrix.items():
+# for kk, vv in v.items():
+# self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
+#
+# bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py
index 4e1923e98..d2150d51d 100644
--- a/brainpy/_src/math/object_transform/tests/test_base.py
+++ b/brainpy/_src/math/object_transform/tests/test_base.py
@@ -239,10 +239,13 @@ def test1(self):
tree = jax.tree.structure(hh)
leaves = jax.tree.leaves(hh)
+ # tree = jax.tree.structure(hh)
+ # leaves = jax.tree.leaves(hh)
print(tree)
print(leaves)
print(jax.tree.unflatten(tree, leaves))
+ # print(jax.tree.unflatten(tree, leaves))
print()
@@ -282,12 +285,16 @@ def all_close(x, y):
assert bm.allclose(x, y)
jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
+ # jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
+ # random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
+ # jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
obj.load_state_dict(random_state)
jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
+ # jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
diff --git a/brainpy/_src/math/object_transform/tests/test_circular_reference.py b/brainpy/_src/math/object_transform/tests/test_circular_reference.py
index 61606d36e..8ef89dfca 100644
--- a/brainpy/_src/math/object_transform/tests/test_circular_reference.py
+++ b/brainpy/_src/math/object_transform/tests/test_circular_reference.py
@@ -65,7 +65,7 @@ def test_nodes():
A.pre = B
B.pre = A
- net = bp.dyn.Network(A, B)
+ net = bp.Network(A, B)
abs_nodes = net.nodes(method='absolute')
rel_nodes = net.nodes(method='relative')
print()
diff --git a/brainpy/_src/math/object_transform/tests/test_collector.py b/brainpy/_src/math/object_transform/tests/test_collector.py
index 9c3d5dde6..17ba00ec9 100644
--- a/brainpy/_src/math/object_transform/tests/test_collector.py
+++ b/brainpy/_src/math/object_transform/tests/test_collector.py
@@ -7,7 +7,7 @@
import brainpy as bp
-class GABAa_without_Variable(bp.TwoEndConn):
+class GABAa_without_Variable(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa_without_Variable, self).__init__(pre=pre, post=post, **kwargs)
@@ -192,7 +192,7 @@ def test_neu_nodes_1():
assert len(neu.nodes(method='relative', include_self=False)) == 1
-class GABAa_with_Variable(bp.TwoEndConn):
+class GABAa_with_Variable(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa_with_Variable, self).__init__(pre=pre, post=post, **kwargs)
diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py
index 7a04c2488..b48f75042 100644
--- a/brainpy/_src/math/object_transform/tests/test_controls.py
+++ b/brainpy/_src/math/object_transform/tests/test_controls.py
@@ -234,7 +234,7 @@ def f1():
branches=[f1,
lambda: 2, lambda: 3,
lambda: 4, lambda: 5],
- dyn_vars=var_a,
+ # dyn_vars=var_a,
show_code=True)
self.assertTrue(f(11) == 1)
diff --git a/brainpy/_src/math/object_transform/tests/test_jit.py b/brainpy/_src/math/object_transform/tests/test_jit.py
index d52903d43..16d0301d4 100644
--- a/brainpy/_src/math/object_transform/tests/test_jit.py
+++ b/brainpy/_src/math/object_transform/tests/test_jit.py
@@ -157,7 +157,7 @@ class MyObj:
def __init__(self):
self.a = bm.Variable(bm.ones(2))
- @bm.cls_jit(static_argnums=1)
+ @bm.cls_jit(static_argnums=0)
def f(self, b, c):
self.a.value *= b
self.a.value /= c
diff --git a/brainpy/_src/math/object_transform/tests/test_naming.py b/brainpy/_src/math/object_transform/tests/test_naming.py
new file mode 100644
index 000000000..bceee561e
--- /dev/null
+++ b/brainpy/_src/math/object_transform/tests/test_naming.py
@@ -0,0 +1,31 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import brainpy as bp
+import brainpy.math as bm
+import unittest
+
+
+class TestNaming(unittest.TestCase):
+
+ def test_clear_name_cache(self):
+ lif = bp.dyn.LifRef(1, name='a')
+ with self.assertRaises(bp.errors.UniqueNameError):
+ lif = bp.dyn.LifRef(1, name='a')
+ bm.clear_name_cache(ignore_warn=True)
+ lif = bp.dyn.LifRef(1, name='a')
+ bm.clear_name_cache()
+ bm.clear_buffer_memory(array=True, compilation=True)
diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py
deleted file mode 100644
index 21c222c00..000000000
--- a/brainpy/_src/math/op_register/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .numba_approach import (CustomOpByNumba,
- register_op_with_numba,
- compile_cpu_signature_with_numba)
-from .base import XLACustomOp
-from .utils import register_general_batching
-from .taichi_aot_based import clear_taichi_aot_caches, count_taichi_aot_kernels
-from .base import XLACustomOp
-from .utils import register_general_batching
diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py
deleted file mode 100644
index 342093ea2..000000000
--- a/brainpy/_src/math/op_register/ad_support.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import functools
-from functools import partial
-
-from jax import tree_util
-from jax.core import Primitive
-from jax.interpreters import ad
-
-__all__ = [
- 'defjvp',
-]
-
-
-def defjvp(primitive, *jvp_rules):
- """Define JVP rules for any JAX primitive.
-
- This function is similar to ``jax.interpreters.ad.defjvp``.
- However, the JAX one only supports primitive with ``multiple_results=False``.
- ``brainpy.math.defjvp`` enables to define the independent JVP rule for
- each input parameter no matter ``multiple_results=False/True``.
-
- For examples, please see ``test_ad_support.py``.
-
- Args:
- primitive: Primitive, XLACustomOp.
- *jvp_rules: The JVP translation rule for each primal.
- """
- assert isinstance(primitive, Primitive)
- if primitive.multiple_results:
- ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
- else:
- ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
-
-
-def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
- assert primitive.multiple_results
- val_out = tuple(primitive.bind(*primals, **params))
- tree = tree_util.tree_structure(val_out)
- tangents_out = []
- for rule, t in zip(jvp_rules, tangents):
- if rule is not None and type(t) is not ad.Zero:
- r = tuple(rule(t, *primals, **params))
- tangents_out.append(r)
- assert tree_util.tree_structure(r) == tree
- return val_out, functools.reduce(_add_tangents,
- tangents_out,
- tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
-
-
-def _add_tangents(xs, ys):
- return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
-
diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
deleted file mode 100644
index 5af5a7e3f..000000000
--- a/brainpy/_src/math/op_register/base.py
+++ /dev/null
@@ -1,223 +0,0 @@
-from functools import partial
-from typing import Callable, Sequence, Tuple, Protocol, Optional, Union
-
-import jax
-import numpy as np
-from jax.interpreters import xla, batching, ad, mlir
-
-from brainpy._src.dependency_check import import_numba, import_cupy_jit
-from brainpy._src.math.ndarray import Array
-from brainpy._src.math.object_transform.base import BrainPyObject
-
-if jax.__version__ >= '0.4.16':
- from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
- from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
- register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
- from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
- register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
-else:
- from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
- from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
- register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
- from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
- register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
-from .utils import register_general_batching
-from brainpy._src.math.op_register.ad_support import defjvp
-
-numba = import_numba(error_if_not_found=False)
-cp_jit = import_cupy_jit(error_if_not_found=False)
-
-__all__ = [
- 'XLACustomOp',
-]
-
-
-class ShapeDtype(Protocol):
-
- @property
- def shape(self) -> Tuple[int, ...]:
- ...
-
- @property
- def dtype(self) -> np.dtype:
- ...
-
-
-class XLACustomOp(BrainPyObject):
- """Creating a XLA custom call operator.
-
- For more information, please refer to the tutorials above:
- Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html
- Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html
- CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html
-
- Args:
- cpu_kernel: Callable. The function defines the computation on CPU backend.
- gpu_kernel: Callable. The function defines the computation on GPU backend.
- batching_translation: Callable. The batching translation rule of JAX.
- jvp_translation: Callable. The JVP translation rule of JAX.
- transpose_translation: Callable. The transpose translation rule of JAX.
- outs: optional. The output information.
- name: str. The primitive name.
- """
-
- def __init__(
- self,
- cpu_kernel: Callable = None,
- gpu_kernel: Union[Callable, str] = None,
- batching_translation: Callable = None,
- jvp_translation: Callable = None,
- transpose_translation: Callable = None,
- outs: Optional[Callable] = None,
- name: str = None,
- ):
- super().__init__(name)
-
- # set cpu_kernel and gpu_kernel
- self.cpu_kernel = cpu_kernel
- self.gpu_kernel = gpu_kernel
-
- # primitive
- self.primitive = jax.core.Primitive(self.name)
- self.primitive.multiple_results = True
-
- # abstract evaluation
- self.outs = outs
- self.primitive.def_abstract_eval(_abstract_eval)
- self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
-
- # cpu function
- cpu_checked = False
- if cpu_kernel is None:
- cpu_checked = True
- if numba is not None: # numba
- from numba.core.dispatcher import Dispatcher
- if isinstance(cpu_kernel, Dispatcher):
- register_numba_cpu_translation_rule(self.primitive, cpu_kernel)
- cpu_checked = True
- if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi
- register_taichi_cpu_translation_rule(self.primitive, cpu_kernel)
- cpu_checked = True
- if not cpu_checked:
- raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. '
- f'But we got {cpu_kernel}')
-
- # gpu function
- gpu_checked = False
- if gpu_kernel is None:
- gpu_checked = True
- elif hasattr(gpu_kernel, 'kernel'): # cupy RawModule
- register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel)
- gpu_checked = True
- elif hasattr(gpu_kernel, '_mode'): # cupy JIT Kernel
- register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel)
- gpu_checked = True
- elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi
- register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
- gpu_checked = True
- if not gpu_checked:
- raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')
-
- # batching rule
- if batching_translation is None:
- register_general_batching(self.primitive)
- else:
- batching.primitive_batchers[self.primitive] = batching_translation
-
- # jvp rule
- if jvp_translation is not None:
- ad.primitive_jvps[self.primitive] = jvp_translation
-
- # transpose rule
- if transpose_translation is not None:
- ad.primitive_transposes[self.primitive] = transpose_translation
-
- def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
- if outs is None:
- if self.outs is None:
- raise ValueError('The output information is not defined.')
- outs = self.outs(*ins, **kwargs)
- assert outs is not None
- outs = tuple([_transform_to_shapedarray(o) for o in outs])
- ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
- return self.primitive.bind(*ins, outs=outs, **kwargs)
-
- def def_abstract_eval(self, fun):
- """Define the abstract evaluation function.
-
- Args:
- fun: The abstract evaluation function.
- """
- self.primitive.def_abstract_eval(fun)
-
- def def_batching_rule(self, fun):
- """Define the batching rule.
-
- Args:
- fun: The batching rule.
- """
- batching.primitive_batchers[self.primitive] = fun
-
- def def_jvp_rule(self, fun):
- """Define the JVP rule.
-
- Args:
- fun: The JVP rule.
- """
- ad.primitive_jvps[self.primitive] = fun
-
- def defjvp(self, *jvp_rules):
- """Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``, but supports the Primitive with multiple results.
-
- Args:
- jvp_rules: The JVP rules.
- """
- defjvp(self.primitive, *jvp_rules)
-
- def def_transpose_rule(self, fun):
- """Define the transpose rule.
-
- Args:
- fun: The transpose rule.
- """
- ad.primitive_transposes[self.primitive] = fun
-
- def def_xla_translation(self, platform, fun):
- """Define the XLA translation rule.
-
- Args:
- platform: str. The computing platform.
- fun: The XLA translation rule.
- """
- xla.backend_specific_translations[platform][self.primitive] = fun
-
- def def_mlir_lowering(self, platform, fun):
- """Define the MLIR lowering rule.
-
- Args:
- platform: str. The computing platform.
- fun: The lowering rule.
- """
- mlir.register_lowering(self.primitive, fun, platform)
-
-
-def _abstract_eval(*args, **kwargs):
- return [jax.core.ShapedArray(out_shape.shape, out_shape.dtype)
- for out_shape in kwargs['outs']]
-
-
-def _is_bp_array(a):
- return isinstance(a, Array)
-
-
-def _transform_to_array(a):
- if isinstance(a, Array):
- return a.value
- elif isinstance(a, jax.Array):
- return a
- else:
- return jax.numpy.asarray(a)
-
-
-def _transform_to_shapedarray(a):
- return jax.core.ShapedArray(a.shape, a.dtype)
diff --git a/brainpy/_src/math/op_register/cupy_based.py b/brainpy/_src/math/op_register/cupy_based.py
deleted file mode 100644
index ad6befecf..000000000
--- a/brainpy/_src/math/op_register/cupy_based.py
+++ /dev/null
@@ -1,279 +0,0 @@
-from functools import partial, reduce
-from typing import List, Tuple
-
-import jax
-import numpy as np
-from jax.interpreters import xla, mlir
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
-
-from brainpy._src.dependency_check import (import_cupy,
- import_cupy_jit,
- import_brainpylib_gpu_ops)
-from brainpy._src.math.op_register.utils import _shape_to_layout
-from brainpy.errors import PackageMissingError
-
-cp = import_cupy(error_if_not_found=False)
-cp_jit = import_cupy_jit(error_if_not_found=False)
-
-# convert type to number
-type_number_map = {
- int: 0,
- float: 1,
- bool: 2,
- np.dtype('int32'): 0,
- np.dtype('float32'): 1,
- np.dtype('bool'): 2,
- np.dtype('uint8'): 3,
- np.dtype('uint16'): 4,
- np.dtype('uint32'): 5,
- np.dtype('uint64'): 6,
- np.dtype('int8'): 7,
- np.dtype('int16'): 8,
- np.dtype('int64'): 9,
- np.dtype('float16'): 10,
- np.dtype('float64'): 11,
-}
-
-
-def _preprocess_kernel_call_gpu(
- grid: Tuple[int],
- block: Tuple[int],
- func_ptr: int,
- shared_mem: int,
- *ins,
- outs: List[jax.ShapeDtypeStruct],
-):
- grid = (grid + (1, 1))[:3]
- block = (block + (1, 1))[:3]
- in_num = len(ins)
- out_num = len(outs)
- in_out_num = [in_num, out_num]
-
- out_type_list = [0] * out_num
- out_elem_count_list = [0] * out_num
-
- for i, value in enumerate(outs):
- out_type_list[i] = type_number_map[value.dtype]
- out_elem_count_list[i] = reduce(lambda x, y: x * y, value.shape)
-
- grid = ",".join(str(i) for i in grid)
- block = ",".join(str(i) for i in block)
- in_out_num_str = ",".join(str(i) for i in in_out_num)
- out_type_list_str = ",".join(str(i) for i in out_type_list)
- out_elem_count_list_str = ",".join(str(i) for i in out_elem_count_list)
-
- opaque = (bytes(str(func_ptr), encoding='utf-8') + b';' +
- bytes(str(shared_mem), encoding='utf-8') + b';' +
- bytes(in_out_num_str, encoding='utf-8') + b';' +
- bytes(grid, encoding='utf-8') + b';' +
- bytes(block, encoding='utf-8') + b';' +
- bytes(out_type_list_str, encoding='utf-8') + b';' +
- bytes(out_elem_count_list_str, encoding='utf-8') + b';')
- return opaque
-
-
-def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
- grid = kwargs.get('grid', None)
- block = kwargs.get('block', None)
- shared_mem = kwargs.get('shared_mem', 0)
- if grid is None or block is None:
- raise ValueError('The grid and block should be specified for the cupy kernel.')
-
- # preprocess
- import_brainpylib_gpu_ops()
- # THE KEY:
- # - using the kernel pointer at "kernel.kernel.ptr"
- opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])
-
- # create custom call
- return xla_client.ops.CustomCallWithLayout(
- c,
- b'cupy_kernel_call_gpu',
- operands=ins,
- operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
- shape_with_layout=xla_client.Shape.tuple_shape(
- [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
- for value in kwargs['outs']]
- ),
- opaque=opaque,
- )
-
-
-def register_cupy_raw_module_xla_gpu_translation_rule(primitive, gpu_kernel):
- xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_raw_module_xla_gpu_translation_rule, gpu_kernel)
-
-
-def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
- grid = kwargs.get('grid', None)
- block = kwargs.get('block', None)
- shared_mem = kwargs.get('shared_mem', 0)
- if grid is None or block is None:
- raise ValueError('The grid and block should be specified for the cupy kernel.')
-
- # preprocess
- import_brainpylib_gpu_ops()
- opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])
-
- input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
- result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
- output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out]
-
- return custom_call(
- call_target_name='cupy_kernel_call_gpu',
- operands=ins,
- operand_layouts=list(input_layouts),
- result_layouts=list(output_layouts),
- result_types=list(result_types),
- backend_config=opaque,
- has_side_effect=False,
- ).results
-
-
-def register_cupy_raw_module_mlir_gpu_translation_rule(primitive, gpu_kernel):
- if cp is None:
- raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule')
-
- rule = partial(_cupy_raw_module_mlir_gpu_translation_rule, gpu_kernel)
- mlir.register_lowering(primitive, rule, platform='gpu')
-
-
-def _to_cupy_array_or_scalar(dtype, ndim):
- # THE KEY
- # - using the cupy jit compiler to get the type
- if ndim != 0:
- t = cp_jit._cuda_types.CArray(dtype=dtype,
- ndim=ndim,
- is_c_contiguous=True,
- index_32_bits=True)
- else:
- t = cp_jit._cuda_types.Scalar(dtype=dtype)
- return t
-
-
-def _compile_kernel_xla(kernel, in_types):
- # THE KEY
- # - get the kernel function from the cache
- device_id = cp.cuda.get_device_id()
- kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None))
-
- if kern is None:
- # THE KEY:
- # - compile the kernel function
- result = kernel._cached_codes.get(in_types)
- if result is None:
- result = cp_jit._compile.transpile(
- kernel._func,
- ['extern "C"', '__global__'],
- 'cuda',
- in_types,
- cp_jit._cuda_types.void,
- )
- kernel._cached_codes[in_types] = result
- fname = result.func_name
- enable_cg = result.enable_cooperative_groups
- options = result.options
- backend = result.backend
- if backend == 'nvcc':
- options += ('-DCUPY_JIT_NVCC',)
- jitify = result.jitify
- module = cp._core.core.compile_with_cache(
- source=result.code,
- options=options,
- backend=backend,
- jitify=jitify,
- )
- kern = module.get_function(fname)
- kernel._cache[(in_types, device_id)] = (kern, enable_cg)
-
- return kern
-
-
-def get_jit_kernel_xla(kernel, c, *ins, outs):
- # get the input types
- in_types = []
- for x in ins:
- x = c.get_shape(x)
- in_types.append(_to_cupy_array_or_scalar(x.element_type(), len(x.dimensions())))
- for x in outs:
- in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
- in_types = tuple(in_types)
- # compile the kernel
- return _compile_kernel_xla(kernel, in_types)
-
-
-def get_jit_kernel_mlir(kernel, c):
- # get the input types
- in_types = []
- for x in c.avals_in:
- in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
- for x in c.avals_out:
- in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
- in_types = tuple(in_types)
- # compile the kernel
- return _compile_kernel_xla(kernel, in_types)
-
-
-def _cupy_jit_kernel_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
- kernel_func = get_jit_kernel_xla(kernel, c, *ins, outs=kwargs['outs'])
- grid = kwargs.get('grid', None)
- block = kwargs.get('block', None)
- shared_mem = kwargs.get('shared_mem', 0)
- if grid is None or block is None:
- raise ValueError('The grid and block should be specified for the cupy kernel.')
-
- # preprocess
- import_brainpylib_gpu_ops()
- opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs'])
-
- # create custom call
- return xla_client.ops.CustomCallWithLayout(
- c,
- b'cupy_kernel_call_gpu',
- operands=ins,
- operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
- shape_with_layout=xla_client.Shape.tuple_shape(
- [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
- for value in kwargs['outs']]
- ),
- opaque=opaque,
- )
-
-
-def register_cupy_jit_kernel_xla_gpu_translation_rule(primitive, gpu_kernel):
- xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_jit_kernel_xla_gpu_translation_rule, gpu_kernel)
-
-
-def _cupy_jit_kernel_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
- kernel_func = get_jit_kernel_mlir(kernel, c)
- grid = kwargs.get('grid', None)
- block = kwargs.get('block', None)
- shared_mem = kwargs.get('shared_mem', 0)
- if grid is None or block is None:
- raise ValueError('The grid and block should be specified for the cupy kernel.')
-
- # preprocess
- import_brainpylib_gpu_ops()
- opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs'])
-
- input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
- result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
- output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out]
-
- return custom_call(
- call_target_name='cupy_kernel_call_gpu',
- operands=ins,
- operand_layouts=list(input_layouts),
- result_layouts=list(output_layouts),
- result_types=list(result_types),
- backend_config=opaque,
- has_side_effect=False,
- ).results
-
-
-def register_cupy_jit_kernel_mlir_gpu_translation_rule(primitive, gpu_kernel):
- if cp is None:
- raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule')
-
- rule = partial(_cupy_jit_kernel_mlir_gpu_translation_rule, gpu_kernel)
- mlir.register_lowering(primitive, rule, platform='gpu')
diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py
deleted file mode 100644
index 5bbd04e0c..000000000
--- a/brainpy/_src/math/op_register/numba_approach/__init__.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from functools import partial
-from typing import Callable
-from typing import Union, Sequence
-
-import jax
-from jax.interpreters import xla, batching, ad
-from jax.tree_util import tree_map
-
-from brainpy._src.dependency_check import import_numba
-from brainpy._src.math.ndarray import Array
-from brainpy._src.math.object_transform.base import BrainPyObject
-from brainpy.errors import PackageMissingError
-from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba
-
-numba = import_numba(error_if_not_found=False)
-
-
-__all__ = [
- 'CustomOpByNumba',
- 'register_op_with_numba',
- 'compile_cpu_signature_with_numba',
-]
-
-
-class CustomOpByNumba(BrainPyObject):
- """Creating a XLA custom call operator with Numba JIT on CPU backend.
-
- Parameters
- ----------
- name: str
- The name of operator.
- eval_shape: callable
- The function to evaluate the shape and dtype of the output according to the input.
- This function should receive the abstract information of inputs, and return the
- abstract information of the outputs. For example:
-
- >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...):
- >>> return out1_info, out2_info
-  con_compute: callable
-    The function that performs the concrete computation. It receives the output
-    arrays followed by the input arrays (the generated CPU wrapper calls
-    ``con_compute(outs, ins)``) and writes the results into the outputs in place.
-    For example:
-
-    >>> def con_compute(outs, ins):
-    >>>   pass
- """
-
- def __init__(
- self,
- eval_shape: Callable = None,
- con_compute: Callable = None,
- name: str = None,
- batching_translation: Callable = None,
- jvp_translation: Callable = None,
- transpose_translation: Callable = None,
- multiple_results: bool = True,
- ):
- super().__init__(name=name)
-
- # abstract evaluation function
- if eval_shape is None:
- raise ValueError('Must provide "eval_shape" for abstract evaluation.')
-
- # cpu function
- cpu_func = con_compute
-
- # register OP
- self.op = register_op_with_numba(
- self.name,
- cpu_func=cpu_func,
- out_shapes=eval_shape,
- batching_translation=batching_translation,
- jvp_translation=jvp_translation,
- transpose_translation=transpose_translation,
- multiple_results=multiple_results,
- )
-
- def __call__(self, *args, **kwargs):
- args = tree_map(lambda a: a.value if isinstance(a, Array) else a,
- args, is_leaf=lambda a: isinstance(a, Array))
- kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a,
- kwargs, is_leaf=lambda a: isinstance(a, Array))
- res = self.op.bind(*args, **kwargs)
- return res
-
-
-def register_op_with_numba(
- op_name: str,
- cpu_func: Callable,
- out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]],
- gpu_func_translation: Callable = None,
- batching_translation: Callable = None,
- jvp_translation: Callable = None,
- transpose_translation: Callable = None,
- multiple_results: bool = False,
-):
-  """
-  Convert a Numba-jitted function into a JAX/XLA-compatible primitive.
-
- Parameters
- ----------
- op_name: str
-    Name of the operator.
-
- cpu_func: Callable
-    A callable Numba-jitted function or pure Python function (a lambda is fine) running on the CPU.
-
- out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None
-    Output shapes of the target function. `out_shapes` can be a `ShapedArray` or
-    a sequence of `ShapedArray`. If it is a function, it takes the input shapes and
-    dtypes as arguments and should return the corresponding output `ShapedArray`s.
-
- gpu_func_translation: Callable
- A callable cuda-jitted kernel running on GPU.
-
- batching_translation: Callable
- The batching translation for the primitive.
-
- jvp_translation: Callable
- The forward autodiff translation rule.
-
- transpose_translation: Callable
- The backward autodiff translation rule.
-
- multiple_results: bool
- Whether the primitive returns multiple results. Default is False.
-
- Returns
- -------
- op: core.Primitive
- A JAX Primitive object.
- """
-
- if jax.__version__ > '0.4.23':
- raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are '
- f'only supported in JAX version <= 0.4.23. \n'
- f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. '
- f'For more information, please refer to the documentation: '
- f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.')
-
- if numba is None:
- raise PackageMissingError.by_purpose('numba', 'custom op with numba')
-
- if out_shapes is None:
- raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or '
- 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument '
- 'shapes and dtypes and should return correct output shapes of `ShapedArray`.')
-
- prim = jax.core.Primitive(op_name)
- prim.multiple_results = multiple_results
-
- # user defined function
- from numba.core.dispatcher import Dispatcher
- if not isinstance(cpu_func, Dispatcher):
- cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func)
-
- # output shape evaluation function
- def abs_eval_rule(*input_shapes, **info):
- if callable(out_shapes):
- shapes = out_shapes(*input_shapes, **info)
- else:
- shapes = out_shapes
-
- if isinstance(shapes, jax.core.ShapedArray):
-      assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one output."
-    elif isinstance(shapes, (tuple, list)):
-      assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple outputs."
- for elem in shapes:
- if not isinstance(elem, jax.core.ShapedArray):
-          raise ValueError(f'Elements in "out_shapes" must be instances of '
-                           f'jax.core.ShapedArray, but we got '
-                           f'{type(elem)}: {elem}')
- else:
- raise ValueError(f'Unknown type {type(shapes)}, only '
- f'supports function, ShapedArray or '
- f'list/tuple of ShapedArray.')
- return shapes
-
- # cpu function
- prim.def_abstract_eval(abs_eval_rule)
- prim.def_impl(partial(xla.apply_primitive, prim))
- xla.backend_specific_translations['cpu'][prim] = partial(_cpu_translation,
- cpu_func,
- abs_eval_rule,
- multiple_results)
-
- # gpu function
- if gpu_func_translation is not None:
- xla.backend_specific_translations['gpu'][prim] = gpu_func_translation
-
- # batching
- if batching_translation is not None:
- batching.primitive_batchers[prim] = batching_translation
-
- # jvp
- if jvp_translation is not None:
- ad.primitive_jvps[prim] = jvp_translation
-
- # transpose
- if transpose_translation is not None:
- ad.primitive_transposes[prim] = transpose_translation
-
- return prim
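
The deprecation error above points users at `brainpy.math.XLACustomOp`. As a rough sketch of that replacement path (mirroring the deleted `test_numba_based.py` later in this diff; it assumes `numba` is installed and runs on the CPU backend):

```python
import jax
import brainpy.math as bm
from brainpy._src.dependency_check import import_numba

numba = import_numba(error_if_not_found=False)
bm.set_platform('cpu')


@numba.njit(fastmath=True)
def numba_event_sum(weight, indices, vector, out):
    # Outputs come after the inputs and are written in place.
    out.fill(0)
    w = weight[()]  # the scalar weight arrives as a 0-d array
    for i in range(vector.shape[0]):
        if vector[i]:
            for j in indices[i]:
                out[j] += w


op = bm.XLACustomOp(numba_event_sum)
indices = bm.random.randint(0, 100, (100, 8))
spikes = bm.random.rand(100) < 0.1
out = op(1., indices, spikes, outs=[jax.ShapeDtypeStruct([100], dtype=bm.float32)])
```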
diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
deleted file mode 100644
index 4b06effdf..000000000
--- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import ctypes
-
-from jax import dtypes, numpy as jnp
-from jax.core import ShapedArray
-from jax.lib import xla_client
-
-from brainpy._src.dependency_check import import_numba
-
-numba = import_numba(error_if_not_found=False)
-ctypes.pythonapi.PyCapsule_New.argtypes = [
- ctypes.c_void_p, # void* pointer
- ctypes.c_char_p, # const char *name
- ctypes.c_void_p, # PyCapsule_Destructor destructor
-]
-ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
-
-__all__ = [
- '_cpu_translation',
- 'compile_cpu_signature_with_numba',
-]
-
-if numba is not None:
- from numba import types, carray, cfunc
-
-
-def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info):
- target_name, inputs, input_shapes, xla_output_shapes = \
- compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info)
- return xla_client.ops.CustomCallWithLayout(
- c,
- target_name,
- operands=inputs,
- operand_shapes_with_layout=input_shapes,
- shape_with_layout=xla_output_shapes,
- )
-
-
-def _cpu_signature(
- func,
- input_dtypes,
- input_shapes,
- output_dtypes,
- output_shapes,
- multiple_results: bool,
- debug: bool = False
-):
- code_scope = dict(
- func_to_call=func,
- input_shapes=input_shapes,
- input_dtypes=input_dtypes,
- output_shapes=output_shapes,
- output_dtypes=output_dtypes,
- carray=carray,
- )
-
- # inputs
- if len(input_shapes) > 1:
- args_in = [
- f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
- for i in range(len(input_shapes))
- ]
- args_in = '(\n ' + "\n ".join(args_in) + '\n )'
- else:
- args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'
-
- # outputs
- if multiple_results:
- args_out = [
- f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
- for i in range(len(output_shapes))
- ]
- args_out = '(\n ' + "\n ".join(args_out) + '\n )'
- else:
- args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'
-
- # function body
- code_string = '''
-def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
- args_out = {args_out}
- args_in = {args_in}
- func_to_call(args_out, args_in)
- '''.format(args_in=args_in,
- args_out=args_out)
- if debug: print(code_string)
- exec(compile(code_string.strip(), '', 'exec'), code_scope)
-
- new_f = code_scope['xla_cpu_custom_call_target']
- if multiple_results:
- xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr),
- types.CPointer(types.voidptr)))(new_f)
- else:
- xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f)
- target_name = xla_c_rule.native_name.encode("ascii")
- capsule = ctypes.pythonapi.PyCapsule_New(
- xla_c_rule.address, # A CFFI pointer to a function
- b"xla._CUSTOM_CALL_TARGET", # A binary string
- None # PyCapsule object run at destruction
- )
- xla_client.register_custom_call_target(target_name, capsule, "cpu")
- return target_name
-
-
-def compile_cpu_signature_with_numba(
- c,
- func,
- abs_eval_fn,
- multiple_results,
- inputs: tuple,
- description: dict = None,
-):
- input_layouts = [c.get_shape(arg) for arg in inputs]
- info_inputs = []
- if description is None: description = dict()
- for v in description.values():
- if isinstance(v, (int, float)):
- input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ()))
- info_inputs.append(xla_client.ops.ConstantLiteral(c, v))
- elif isinstance(v, (tuple, list)):
- v = jnp.asarray(v)
- input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1))))
- info_inputs.append(xla_client.ops.Constant(c, v))
- else:
- raise TypeError
- input_layouts = tuple(input_layouts)
- input_dtypes = tuple(shape.element_type() for shape in input_layouts)
- input_dimensions = tuple(shape.dimensions() for shape in input_layouts)
- output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type())
- for shape in input_layouts[:len(inputs)]),
- **description)
- if isinstance(output_abstract_arrays, ShapedArray):
- output_abstract_arrays = (output_abstract_arrays,)
- assert not multiple_results
- else:
- assert multiple_results
- output_shapes = tuple(array.shape for array in output_abstract_arrays)
- output_dtypes = tuple(array.dtype for array in output_abstract_arrays)
- output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes)
- target_name = _cpu_signature(func,
- input_dtypes,
- input_dimensions,
- output_dtypes,
- output_shapes,
- multiple_results,
- debug=False)
- output_layouts = [xla_client.Shape.array_shape(*arg)
- for arg in zip(output_dtypes, output_shapes, output_layouts)]
- output_layouts = (xla_client.Shape.tuple_shape(output_layouts)
- if multiple_results else
- output_layouts[0])
- return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts
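
The core trick in the file above is turning a Numba `cfunc` into an XLA custom-call target: `cfunc` yields a raw C function address, which is wrapped in a `PyCapsule` and handed to `xla_client.register_custom_call_target`. A small self-contained sketch of just the Numba half (hypothetical `double_it` kernel with a fixed length of 4, for brevity):

```python
from numba import cfunc, types, carray

# void(double* out, double* in): read four doubles and write their doubled values.
sig = types.void(types.CPointer(types.double), types.CPointer(types.double))


@cfunc(sig)
def double_it(out_ptr, in_ptr):
    inp = carray(in_ptr, (4,))   # dtype inferred from the typed pointer
    out = carray(out_ptr, (4,))
    for i in range(4):
        out[i] = 2.0 * inp[i]


# `double_it.address` is the raw function pointer that the deleted code wrapped in a
# PyCapsule (tagged b"xla._CUSTOM_CALL_TARGET") and registered with XLA.
print(hex(double_it.address))
```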
diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py
deleted file mode 100644
index f461f4277..000000000
--- a/brainpy/_src/math/op_register/numba_based.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import ctypes
-from functools import partial
-
-from jax.interpreters import xla, mlir
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
-
-from brainpy._src.dependency_check import import_numba
-from brainpy.errors import PackageMissingError
-from .utils import _shape_to_layout
-
-numba = import_numba(error_if_not_found=False)
-if numba is not None:
- from numba import types, carray, cfunc
-
-__all__ = [
- 'register_numba_xla_cpu_translation_rule',
- 'register_numba_mlir_cpu_translation_rule',
-]
-
-# [void* pointer,
-# const char *name,
-# PyCapsule_Destructor destructor]
-ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
-ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
-
-
-def _cpu_signature(
- kernel,
- input_dtypes,
- input_shapes,
- output_dtypes,
- output_shapes,
- debug: bool = False
-):
- code_scope = dict(
- func_to_call=kernel,
- input_shapes=input_shapes,
- input_dtypes=input_dtypes,
- output_shapes=output_shapes,
- output_dtypes=output_dtypes,
- carray=carray,
- )
-
- # inputs, outputs, arguments
- args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
- for i in range(len(input_shapes))]
- args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
- for i in range(len(output_shapes))]
- args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
-
- # function body
- code_string = '''
- def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
- {args_in}
- {args_out}
- func_to_call({args_call})
- '''.format(args_in="\n ".join(args_in),
- args_out="\n ".join(args_out),
- args_call=", ".join(args_call))
- if debug: print(code_string)
- exec(compile(code_string.strip(), '', 'exec'), code_scope)
-
- # register
- new_f = code_scope['xla_cpu_custom_call_target']
- xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)))(new_f)
- target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
- capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
- xla_client.register_custom_call_target(target_name, capsule, "cpu")
-
- return target_name
-
-
-def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs):
- outs = kwargs['outs']
-
- # output information
- output_shapes = tuple(out.shape for out in outs)
- output_dtypes = tuple(out.dtype for out in outs)
- output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes)
- output_infos = [xla_client.Shape.array_shape(*arg) for arg in zip(output_dtypes, output_shapes, output_layouts)]
- output_infos = xla_client.Shape.tuple_shape(output_infos)
-
- # input information
- input_layouts = tuple(c.get_shape(arg) for arg in ins)
- input_dtypes = tuple(inp.element_type() for inp in input_layouts)
- input_shapes = tuple(inp.dimensions() for inp in input_layouts)
-
- # compiling
- target_name = _cpu_signature(kernel,
- input_dtypes,
- input_shapes,
- output_dtypes,
- output_shapes,
- debug=debug)
-
- # call
- return xla_client.ops.CustomCallWithLayout(
- c,
- target_name.encode("ascii"),
- operands=tuple(ins),
- operand_shapes_with_layout=input_layouts,
- shape_with_layout=output_infos,
- )
-
-
-def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False):
- if numba is None:
- raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule')
-
-  # the XLA translation table is not supported on jax >= 0.4.24
- xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule,
- cpu_kernel,
- debug)
-
-
-def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
- # output information
- outs = ctx.avals_out
- output_shapes = tuple([out.shape for out in outs])
- output_dtypes = tuple([out.dtype for out in outs])
- output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
- result_types = [mlir.aval_to_ir_type(out) for out in outs]
-
- # input information
- avals_in = ctx.avals_in
- input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
- input_dtypes = tuple(inp.dtype for inp in avals_in)
- input_shapes = tuple(inp.shape for inp in avals_in)
-
- # compiling function
- code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
- output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
- args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
- for i in range(len(input_shapes))]
- if len(output_shapes) > 1:
- args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
- for i in range(len(output_shapes))]
- sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
- else:
- args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
- sig = types.void(types.voidptr, types.CPointer(types.voidptr))
- args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
- code_string = '''
-def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
- {args_in}
- {args_out}
- func_to_call({args_call})
- '''.format(args_in="\n ".join(args_in),
- args_out="\n ".join(args_out),
- args_call=", ".join(args_call))
- if debug:
- print(code_string)
- exec(compile(code_string.strip(), '', 'exec'), code_scope)
- new_f = code_scope['numba_cpu_custom_call_target']
-
- # register
- xla_c_rule = cfunc(sig)(new_f)
- target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
- capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
- xla_client.register_custom_call_target(target_name, capsule, "cpu")
-
- # call
- return custom_call(
- call_target_name=target_name,
- operands=ins,
- operand_layouts=list(input_layouts),
- result_layouts=list(output_layouts),
- result_types=list(result_types),
- has_side_effect=False,
- ).results
-
-
-def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False):
- if numba is None:
-    raise PackageMissingError.by_purpose("numba", 'register numba mlir cpu translation rule')
-
- rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
- mlir.register_lowering(primitive, rule, platform='cpu')
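
Both registration helpers above attach a lowering to an existing JAX primitive: the legacy path writes into `xla.backend_specific_translations`, the newer one goes through `mlir.register_lowering`. For readers less familiar with the MLIR half, here is a minimal sketch of the same registration pattern that uses `mlir.lower_fun` instead of a compiled custom-call target, so it needs no Numba at all:

```python
import jax
import jax.numpy as jnp
from jax.core import Primitive, ShapedArray
from jax.interpreters import mlir

# A toy primitive: double the input.
double_p = Primitive('my_double')
double_p.def_impl(lambda x: x * 2)
double_p.def_abstract_eval(lambda x: ShapedArray(x.shape, x.dtype))

# Register a CPU lowering; BrainPy's rule instead emitted a custom_call to the
# Numba-compiled target registered above.
mlir.register_lowering(double_p,
                       mlir.lower_fun(lambda x: x * 2, multiple_results=False),
                       platform='cpu')

print(jax.jit(lambda x: double_p.bind(x))(jnp.arange(4.)))  # [0. 2. 4. 6.]
```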
diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py
deleted file mode 100644
index 858f338be..000000000
--- a/brainpy/_src/math/op_register/taichi_aot_based.py
+++ /dev/null
@@ -1,526 +0,0 @@
-import contextlib
-import hashlib
-import inspect
-import io
-import os
-import pathlib
-import platform
-import re
-import shutil
-from functools import partial, reduce
-from typing import Any, Sequence, Union
-
-import jax.core
-import numpy as np
-from jax.interpreters import xla, mlir
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
-
-from brainpy._src.dependency_check import (import_taichi,
- import_brainpylib_cpu_ops,
- import_brainpylib_gpu_ops)
-from brainpy.errors import PackageMissingError
-from .utils import _shape_to_layout
-
-
-taichi_cache_path = None
-
-
-# --- UTILS ###
-
-# get the path of home directory on Linux, Windows, Mac
-def get_home_dir():
- return str(pathlib.Path.home())
-
-
-# encode a string with md5
-def encode_md5(source: str) -> str:
- # create md5 object
- md5 = hashlib.md5()
-
- # encode source
- source_encode = source.encode(encoding='utf-8')
-
- # update md5 object
- md5.update(source_encode)
-
- return md5.hexdigest()
-
-
-# check kernels count
-def count_taichi_aot_kernels() -> int:
- """
- Count the number of AOT compiled kernels.
-
- Returns
- -------
- kernels_count: int
- The number of AOT compiled kernels.
-
- """
- if not os.path.exists(kernels_aot_path):
- return 0
- kernels_count = 0
- dir1 = os.listdir(kernels_aot_path)
- for i in dir1:
- dir2 = os.listdir(os.path.join(kernels_aot_path, i))
- kernels_count += len(dir2)
- return kernels_count
-
-
-def clear_taichi_aot_caches(kernels: Union[str, Sequence[str]] = None):
- """
- Clean the cache of the AOT compiled kernels.
-
- Parameters
- ----------
- kernels: str or list of str
- The name of the kernel to be cleaned. If None, all the kernels will be cleaned.
- """
- if kernels is None:
- global taichi_cache_path
- if taichi_cache_path is None:
- from taichi._lib.utils import import_ti_python_core
- taichi_cache_path = import_ti_python_core().get_repo_dir()
- # clean taichi cache
- if os.path.exists(taichi_cache_path):
- shutil.rmtree(taichi_cache_path)
- # clean brainpy-taichi AOT cache
- if os.path.exists(kernels_aot_path):
- shutil.rmtree(kernels_aot_path)
- return
- if isinstance(kernels, str):
- kernels = [kernels]
- if not isinstance(kernels, list):
-    raise TypeError(f'kernels must be a str or a sequence of str, but got {type(kernels)}')
- # clear brainpy kernel cache
- for kernel_name in kernels:
- if os.path.exists(os.path.join(kernels_aot_path, kernel_name)):
- shutil.rmtree(os.path.join(kernels_aot_path, kernel_name))
-
-
-# TODO: collecting a function's source together with its dependencies by regex
-# matching on call sites is fragile; a more robust approach is needed.
-def get_source_with_dependencies(func, visited=None):
- if visited is None:
- visited = set()
-
- source = inspect.getsource(func)
- if func in visited:
- return ''
-
- visited.add(func)
- module = inspect.getmodule(func)
- dependent_funcs = re.findall(r'(\w+)\(', source)
-
- for func_name in dependent_funcs:
- dependent_func = getattr(module, func_name, None)
- if callable(dependent_func):
- source += get_source_with_dependencies(dependent_func, visited)
- return source
-
-
-# check if Metal is supported
-def is_metal_supported():
- # first check if we are on macOS
- if platform.system() != 'Darwin':
- return False
- if platform.processor() != 'arm':
- return False
- return True
-
-
-# --- VARIABLES ###
-home_path = get_home_dir()
-kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels')
-is_metal_device = is_metal_supported()
-
-
-# check if a kernel exists in the database
-def _check_kernel_exist(source_md5_encode: str) -> bool:
- # get the realpath of the kernel
- kernel_path = os.path.join(kernels_aot_path, source_md5_encode)
-
- # check whether the kernel exists
- if os.path.exists(kernel_path):
- return True
- else:
- return False
-
-
-# --- KERNEL AOT BUILD ###
-
-
-def _array_to_field(dtype, shape) -> Any:
- ti = import_taichi()
- if dtype == np.bool_:
- dtype = bool
- elif dtype == np.int8:
- dtype = ti.int8
- elif dtype == np.int16:
- dtype = ti.int16
- elif dtype == np.int32:
- dtype = ti.int32
- elif dtype == np.int64:
- dtype = ti.int64
- elif dtype == np.uint8:
- dtype = ti.uint8
- elif dtype == np.uint16:
- dtype = ti.uint16
- elif dtype == np.uint32:
- dtype = ti.uint32
- elif dtype == np.uint64:
- dtype = ti.uint64
- elif dtype == np.float16:
- dtype = ti.float16
- elif dtype == np.float32:
- dtype = ti.float32
- elif dtype == np.float64:
- dtype = ti.float64
- else:
- raise NotImplementedError(f'Currently we do not support dtype {dtype} in Taichi. '
- f'If you think it is necessary, please open an issue at '
- f'https://github.com/brainpy/BrainPy/issues/new')
- return ti.field(dtype=dtype, shape=shape)
-
-
-# build aot kernel
-def _build_kernel(
- source_md5_encode: str,
- kernel: callable,
- ins: dict,
- outs: dict,
- device: str
-):
- ti = import_taichi()
-
- # init arch
- if device == 'cpu':
- if is_metal_device:
- arch = ti.arm64
- device = 'arm64'
- else:
- arch = ti.x64
- elif device == 'gpu':
- arch = ti.cuda
- else:
- raise ValueError(f'Unknown device: {device}')
- with contextlib.redirect_stdout(io.StringIO()):
- ti.init(arch=arch)
-
- # check arch is available
- if ti.lang.impl.current_cfg().arch != arch:
- raise RuntimeError(f"Arch {arch} is not available")
-
- # get kernel name
- kernel_name = kernel.__name__
-
- # replace the name of the func
- kernel.__name__ = f'taichi_kernel_{device}'
-
- # init template_args_dict
- template_args_dict = {}
- for key, value in ins.items():
- template_args_dict[key] = _array_to_field(value[0], value[1])
- for key, value in outs.items():
- template_args_dict[key] = _array_to_field(value[0], value[1])
-
- # make aot dir
- kernel_path = os.path.join(kernels_aot_path, source_md5_encode)
- os.makedirs(kernel_path, exist_ok=True)
-
- # compile kernel
- mod = ti.aot.Module(arch)
- mod.add_kernel(kernel, template_args=template_args_dict)
- mod.save(kernel_path)
-
- # rename kernel name
- kernel.__name__ = kernel_name
-
-
-# --- KERNEL CALL PREPROCESS ###
-
-# convert type to number
-type_number_map = {
- int: 0,
- float: 1,
- bool: 2,
- np.dtype('int32'): 0,
- np.dtype('float32'): 1,
- np.dtype('bool'): 2,
- np.dtype('uint8'): 3,
- np.dtype('uint16'): 4,
- np.dtype('uint32'): 5,
- np.dtype('uint64'): 6,
- np.dtype('int8'): 7,
- np.dtype('int16'): 8,
- np.dtype('int64'): 9,
- np.dtype('float16'): 10,
- np.dtype('float64'): 11,
-}
-
-
-# preprocess kernel call cpu
-def _preprocess_kernel_call_cpu(
- source_md5_encode: str,
- ins: Sequence,
- outs: Sequence,
-) -> list:
- in_out_info = []
- max_dim_count = 0
- for value in ins:
- if value.ndim > max_dim_count:
- max_dim_count = value.ndim
-
- for value in outs:
- if value.ndim > max_dim_count:
- max_dim_count = value.ndim
-
- # kernel_path
- kernel_path = os.path.join(kernels_aot_path, source_md5_encode)
- kernel_path = bytes(kernel_path, encoding='utf-8') + b'\0'
- kernel_path = np.array(list(kernel_path), dtype=np.uint8)
-
- # other args
- in_out_num = np.array([len(ins), len(outs), kernel_path.size], dtype=np.uint32)
- in_out_type_list = np.zeros((len(ins) + len(outs),), dtype=np.uint32)
- in_out_dim_count_list = np.zeros((len(ins) + len(outs),), dtype=np.uint32)
- in_out_elem_count_list = np.zeros((len(ins) + len(outs),), dtype=np.uint32)
- in_out_shape_list = np.zeros((len(ins) + len(outs), max_dim_count), dtype=np.uint32)
-
- for i, value in enumerate(ins):
- in_out_type_list[i] = type_number_map[value.dtype]
- in_out_dim_count_list[i] = value.ndim
- in_out_elem_count_list[i] = value.size
- for j, dim in enumerate(value.shape):
- in_out_shape_list[i, j] = dim
-
- b = len(ins)
- for i, value in enumerate(outs):
- in_out_type_list[i + b] = type_number_map[value.dtype]
- in_out_dim_count_list[i + b] = value.ndim
- in_out_elem_count_list[i + b] = value.size
- for j, dim in enumerate(value.shape):
- in_out_shape_list[i + b, j] = dim
-
- in_out_info.append(in_out_num)
- in_out_info.append(in_out_type_list)
- in_out_info.append(in_out_dim_count_list)
- in_out_info.append(in_out_elem_count_list)
- in_out_info.append(in_out_shape_list)
- in_out_info.append(kernel_path)
-
- return in_out_info
-
-
-def _preprocess_kernel_call_gpu(
- source_md5_encode: str,
- ins: Sequence,
- outs: Sequence,
-) -> bytes:
- # if len(ins) + len(outs) > 8:
- # raise ValueError('The number of ins and outs must be less than 8!')
- kernel_path = os.path.join(kernels_aot_path, source_md5_encode)
-
- # other args
- param_total_num = len(ins) + len(outs)
- in_out_num = [len(ins), len(outs)]
- in_out_type_list = [0] * param_total_num
- in_out_dim_count_list = [0] * param_total_num
- in_out_elem_count_list = [0] * param_total_num
- in_out_shape_list = [0] * param_total_num * 8
-
- for i, value in enumerate(ins):
- in_out_type_list[i] = type_number_map[value.dtype]
- in_out_dim_count_list[i] = value.ndim
- in_out_elem_count_list[i] = value.size
- for j, dim in enumerate(value.shape):
- in_out_shape_list[i * 8 + j] = dim
-
- for i, value in enumerate(outs):
- in_out_type_list[i + len(ins)] = type_number_map[value.dtype]
- in_out_dim_count_list[i + len(ins)] = value.ndim
- in_out_elem_count_list[i + len(ins)] = value.size
- for j, dim in enumerate(value.shape):
- in_out_shape_list[(i + len(ins)) * 8 + j] = dim
-
-  # convert to string
- in_out_num_str = ",".join(str(i) for i in in_out_num)
- in_out_type_list_str = ",".join(str(i) for i in in_out_type_list)
- in_out_dim_count_list_str = ",".join(str(i) for i in in_out_dim_count_list)
- in_out_elem_count_list_str = ",".join(str(i) for i in in_out_elem_count_list)
- in_out_shape_list_str = ",".join(str(i) for i in in_out_shape_list)
-
- opaque = (bytes(in_out_num_str, encoding='utf-8') + b';' +
- bytes(in_out_type_list_str, encoding='utf-8') + b';' +
- bytes(in_out_dim_count_list_str, encoding='utf-8') + b';' +
- bytes(in_out_elem_count_list_str, encoding='utf-8') + b';' +
- bytes(in_out_shape_list_str, encoding='utf-8') + b';' +
- bytes(kernel_path, encoding='utf-8'))
-
- return opaque
-
-
-def _XlaOp_to_ShapedArray(c, xla_op):
- xla_op = c.get_shape(xla_op)
- return jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type())
-
-
-def _mlir_to_ShapedArray(c, op):
- return op
-
-
-def _kernel_to_code(kernel, abs_ins, abs_outs, platform):
- codes = f'[taichi {platform} kernel]\n' + get_source_with_dependencies(kernel)
- codes += '\n[ins]: {}'.format("-".join([f'{v.dtype}[{v.shape}]' for v in abs_ins]))
- codes += '\n[outs]: {}'.format("-".join([f'{v.dtype}[{v.shape}]' for v in abs_outs]))
- return codes
-
-
-def _compile_kernel(abs_ins, kernel, platform: str, **kwargs):
- # input and output abstract information
- abs_outs = kwargs['outs']
-
- # kernel to code
- codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
- source_md5_encode = os.path.join(kernel.__name__, encode_md5(codes))
-
- # create ins, outs dict from kernel's args
- in_num = len(abs_ins)
- names = tuple(inspect.signature(kernel).parameters.keys())
- in_names, out_names = names[:in_num], names[in_num:]
- ins_dict = {key: (abs_ins[i].dtype, abs_ins[i].shape) for i, key in enumerate(in_names)}
- outs_dict = {key: (abs_outs[i].dtype, abs_outs[i].shape) for i, key in enumerate(out_names)}
-
- # build kernels
- if not _check_kernel_exist(source_md5_encode): # TODO: more checking
- try:
- _build_kernel(source_md5_encode, kernel, ins_dict, outs_dict, platform)
- except Exception as e:
- try:
- os.removedirs(os.path.join(kernels_aot_path, source_md5_encode))
- except Exception:
- raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e
- raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e
-
- # returns
- if platform in ['gpu', 'cuda']:
- import_brainpylib_gpu_ops()
- opaque = _preprocess_kernel_call_gpu(source_md5_encode, abs_ins, abs_outs)
- return opaque
- elif platform == 'cpu':
- import_brainpylib_cpu_ops()
- in_out_info = _preprocess_kernel_call_cpu(source_md5_encode, abs_ins, abs_outs)
- return in_out_info
- else:
- raise ValueError(f'Unknown platform: {platform}')
-
-
-def _get_abs_ins(c, ins):
- abs_ins = []
- for v in ins:
- xla_op = c.get_shape(v)
- abs_ins.append(jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type()))
- return abs_ins
-
-
-def _taichi_xla_cpu_translation_rule(kernel, c, *ins, **kwargs):
- in_out_info = _compile_kernel(_get_abs_ins(c, ins), kernel, 'cpu', **kwargs)
- ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins)
- if is_metal_device:
- fn = b'taichi_kernel_aot_call_cpu_arm64'
- else:
- fn = b'taichi_kernel_aot_call_cpu'
-
- return xla_client.ops.CustomCallWithLayout(
- c,
- fn,
- operands=ins,
- operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
- shape_with_layout=xla_client.Shape.tuple_shape(
- [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
- for value in kwargs['outs']]
- ),
- )
-
-
-def _taichi_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
- opaque = _compile_kernel(_get_abs_ins(c, ins), kernel, 'gpu', **kwargs)
- return xla_client.ops.CustomCallWithLayout(
- c,
- b'taichi_kernel_aot_call_gpu',
- operands=ins,
- operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
- shape_with_layout=xla_client.Shape.tuple_shape(
- [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
- for value in kwargs['outs']]
- ),
- opaque=opaque,
- )
-
-
-def register_taichi_aot_xla_cpu_translation_rule(primitive, cpu_kernel):
- xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_xla_cpu_translation_rule, cpu_kernel)
-
-
-def register_taichi_aot_xla_gpu_translation_rule(primitive, gpu_kernel):
- xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_xla_gpu_translation_rule, gpu_kernel)
-
-
-def _taichi_mlir_cpu_translation_rule(kernel, c, *ins, **kwargs):
- in_out_info = _compile_kernel(c.avals_in, kernel, 'cpu', **kwargs)
- ins = [mlir.ir_constant(v) for v in in_out_info] + list(ins)
- input_layouts = [_shape_to_layout(arr.shape) for arr in in_out_info] + [_shape_to_layout(a.shape) for a in c.avals_in]
- output_layouts = tuple([_shape_to_layout(out.shape) for out in c.avals_out])
- result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
- if is_metal_device:
- if len(output_layouts) == 1:
- fn = 'taichi_kernel_aot_call_cpu_arm64_single_result'
- else:
- fn = 'taichi_kernel_aot_call_cpu_arm64'
- else:
- if len(output_layouts) == 1:
- fn = 'taichi_kernel_aot_call_cpu_single_result'
- else:
- fn = 'taichi_kernel_aot_call_cpu'
- return custom_call(
- call_target_name=fn,
- operands=ins,
- operand_layouts=list(input_layouts),
- result_layouts=list(output_layouts),
- result_types=list(result_types),
- has_side_effect=False,
- ).results
-
-
-def _taichi_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
- opaque = _compile_kernel(c.avals_in, kernel, 'gpu', **kwargs)
- input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
- result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
- output_layouts = [_shape_to_layout(out.shape) for out in c.avals_out]
- return custom_call(
- call_target_name='taichi_kernel_aot_call_gpu',
- operands=ins,
- operand_layouts=list(input_layouts),
- result_layouts=list(output_layouts),
- result_types=list(result_types),
- backend_config=opaque,
- has_side_effect=False,
- ).results
-
-
-def register_taichi_aot_mlir_cpu_translation_rule(primitive, cpu_kernel):
- if import_taichi(error_if_not_found=False) is None:
- raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule')
-
- rule = partial(_taichi_mlir_cpu_translation_rule, cpu_kernel)
- mlir.register_lowering(primitive, rule, platform='cpu')
-
-
-def register_taichi_aot_mlir_gpu_translation_rule(primitive, gpu_kernel):
- if import_taichi(error_if_not_found=False) is None:
- raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule')
-
- rule = partial(_taichi_mlir_gpu_translation_rule, gpu_kernel)
- mlir.register_lowering(primitive, rule, platform='gpu')
diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py
deleted file mode 100644
index 2c9f09724..000000000
--- a/brainpy/_src/math/op_register/tests/test_ad_support.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import pytest
-from typing import Tuple
-
-import jax
-from jax import core
-from jax import numpy as jnp
-from jax.interpreters import ad
-
-import brainpy as bp
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_numba
-
-numba = import_numba(error_if_not_found=False)
-if numba is None:
- pytest.skip('no numba', allow_module_level=True)
-
-bm.set_platform('cpu')
-
-
-def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False, ):
- data = jnp.atleast_1d(bm.as_jax(data))
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- vector = bm.as_jax(vector)
- if vector.dtype == jnp.bool_:
- vector = bm.as_jax(vector, dtype=data.dtype)
- outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)]
- if transpose:
- return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
- else:
- return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
-
-
-@numba.njit(fastmath=True)
-def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val):
- res_val.fill(0)
- if values.shape[0] == 1:
- values = values[0]
- for row_i in range(vector.shape[0]):
- v = vector[row_i]
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- res_val[col_indices[j]] += values * v
- else:
- for row_i in range(vector.shape[0]):
- v = vector[row_i]
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- res_val[col_indices[j]] += v * values[j]
-
-
-@numba.njit(fastmath=True, parallel=True, nogil=True)
-def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val):
- res_val.fill(0)
- # csr mat @ vec
- if values.shape[0] == 1:
- values = values[0]
- for row_i in numba.prange(res_val.shape[0]):
- r = 0.
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- r += values * vector[col_indices[j]]
- res_val[row_i] = r
- else:
- for row_i in numba.prange(res_val.shape[0]):
- r = 0.
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- r += values[j] * vector[col_indices[j]]
- res_val[row_i] = r
-
-
-def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
- return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
-
-
-def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
- return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose)
-
-
-def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs):
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
- raise ValueError("Cannot transpose with respect to sparse indices.")
-
- ct = ct[0]
- if ad.is_undefined_primal(vector):
- ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose)
- return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
-
- else:
- if type(ct) is ad.Zero:
- ct_data = ad.Zero(data)
- else:
- if data.aval.shape[0] == 1: # scalar
- ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
- ct_data = jnp.inner(ct, ct_data)
- else: # heterogeneous values
- row, col = bm.sparse.csr_to_coo(indices, indptr)
- ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
- return ct_data, indices, indptr, vector
-
-
-prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp)
-prim_trans.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
-prim_trans.def_transpose_rule(_csrmv_cusparse_transpose)
-
-prim = bm.XLACustomOp(_csr_matvec_numba_imp)
-prim.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
-prim.def_transpose_rule(_csrmv_cusparse_transpose)
-
-
-def sum_op(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)
- return r.sum()
-
- return func
-
-
-def try_a_trial(transpose, shape):
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(0.1)
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- heter_data = rng.random(indices.shape)
- heter_data = bm.as_jax(heter_data)
- vector = rng.random(shape[0] if transpose else shape[1])
- vector = bm.as_jax(vector)
-
- r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs)), argnums=(0, 3))(
- heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
- r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))(
- heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
- print(r5)
- print(r6)
- assert bm.allclose(r5[0], r6[0])
- assert bm.allclose(r5[1], r6[1][0])
-
-
-def test():
- transposes = [True, False]
- shapes = [(100, 200), (10, 1000), (2, 2000)]
-
- for transpose in transposes:
- for shape in shapes:
- try_a_trial(transpose, shape)
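
One detail in the deleted test worth calling out for reviewers: `XLACustomOp.defjvp` takes one JVP rule per positional operand, with `None` marking operands that carry no tangent. An annotated restatement of the registration above:

```python
# csrmv(data, indices, indptr, vector): only `data` and `vector` are differentiable.
prim.defjvp(_csrmv_jvp_mat,   # tangent rule w.r.t. data
            None,             # indices: integer connectivity, no tangent
            None,             # indptr:  integer connectivity, no tangent
            _csrmv_jvp_vec)   # tangent rule w.r.t. vector
prim.def_transpose_rule(_csrmv_cusparse_transpose)  # reverse mode via transposition
```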
diff --git a/brainpy/_src/math/op_register/tests/test_cupy_based.py b/brainpy/_src/math/op_register/tests/test_cupy_based.py
deleted file mode 100644
index 772b61607..000000000
--- a/brainpy/_src/math/op_register/tests/test_cupy_based.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import jax
-import pytest
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_cupy, import_cupy_jit, import_taichi
-
-cp = import_cupy(error_if_not_found=False)
-cp_jit = import_cupy_jit(error_if_not_found=False)
-ti = import_taichi(error_if_not_found=False)
-if cp is None or ti is None:
- pytest.skip('no cupy or taichi', allow_module_level=True)
-bm.set_platform('cpu')
-
-
-def test_cupy_based():
- bm.op_register.clear_taichi_aot_caches()
- # Raw Module
-
- @ti.kernel
- def simpleAdd(x1: ti.types.ndarray(ndim=2),
- x2: ti.types.ndarray(ndim=2),
- n: ti.types.ndarray(ndim=0),
- y: ti.types.ndarray(ndim=2)):
- for i, j in y:
- y[i, j] = x1[i, j] + x2[i, j]
-
- source_code = r'''
- extern "C"{
-
- __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)
- {
- unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
- if (tid < N)
- {
- y[tid] = x1[tid] + x2[tid];
- }
- }
- }
- '''
- N = 10
- x1 = bm.ones((N, N))
- x2 = bm.ones((N, N))
-
- mod = cp.RawModule(code=source_code)
- kernel = mod.get_function('kernel')
-
- prim1 = bm.XLACustomOp(cpu_kernel=simpleAdd, gpu_kernel=kernel)
-
- y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]
-
- print(y)
- assert bm.allclose(y, x1 + x2)
-
- # JIT Kernel
- @ti.kernel
- def elementwise_copy_taichi(x: ti.types.ndarray(ndim=1),
- size: ti.types.ndarray(ndim=1),
- y: ti.types.ndarray(ndim=1)):
- for i in y:
- y[i] = x[i]
-
- @cp_jit.rawkernel()
- def elementwise_copy(x, size, y):
- tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x
- ntid = cp_jit.gridDim.x * cp_jit.blockDim.x
- for i in range(tid, size, ntid):
- y[i] = x[i]
-
- size = 100
- x = bm.ones((size,))
-
- prim2 = bm.XLACustomOp(cpu_kernel=elementwise_copy_taichi, gpu_kernel=elementwise_copy)
-
- y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]
-
- print(y)
- assert bm.allclose(y, x)
-
-# test_cupy_based()
diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py
deleted file mode 100644
index 28b80d0f4..000000000
--- a/brainpy/_src/math/op_register/tests/test_numba_based.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import jax.core
-import pytest
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_numba
-
-numba = import_numba(error_if_not_found=False)
-if numba is None:
- pytest.skip('no numba', allow_module_level=True)
-
-bm.set_platform('cpu')
-
-
-@numba.njit(fastmath=True)
-def numba_event_csrmv(weight, indices, vector, outs):
- outs.fill(0)
- weight = weight[()] # 0d
- for row_i in range(vector.shape[0]):
- if vector[row_i]:
- for j in indices[row_i]:
- outs[j] += weight
-
-
-prim = bm.XLACustomOp(numba_event_csrmv)
-
-
-def call(s=100):
- indices = bm.random.randint(0, s, (s, 80))
- vector = bm.random.rand(s) < 0.1
- out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])
- print(out[0].shape)
-
-
-def test_event_ELL():
- call(1000)
- call(100)
- bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py
deleted file mode 100644
index ea6dcadcf..000000000
--- a/brainpy/_src/math/op_register/tests/test_taichi_based.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import jax
-import jax.numpy as jnp
-import pytest
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-ti = import_taichi(error_if_not_found=False)
-if ti is None:
- pytest.skip('no taichi', allow_module_level=True)
-
-bm.set_platform('cpu')
-
-@ti.func
-def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:
- return weight[None]
-
-
-@ti.func
-def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
- out[index] += weight_val
-
-
-@ti.kernel
-def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
- vector: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=0),
- out: ti.types.ndarray(ndim=1)):
- weight_val = get_weight(weight)
- num_rows, num_cols = indices.shape
- ti.loop_config(serialize=True)
- for i in range(num_rows):
- if vector[i]:
- for j in range(num_cols):
- update_output(out, indices[i, j], weight_val)
-
-@ti.kernel
-def event_ell_gpu(indices: ti.types.ndarray(ndim=2),
- vector: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=0),
- out: ti.types.ndarray(ndim=1)):
- weight_val = get_weight(weight)
- num_rows, num_cols = indices.shape
- for i in range(num_rows):
- if vector[i]:
- for j in range(num_cols):
- update_output(out, indices[i, j], weight_val)
-
-prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)
-
-
-def test_taichi_op_register():
- s = 1000
- indices = bm.random.randint(0, s, (s, 1000))
- vector = bm.random.rand(s) < 0.1
-
- out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
-
- out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
-
- print(out)
-
-# test_taichi_op_register()
diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
deleted file mode 100644
index b534435dc..000000000
--- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import jax
-import jax.numpy as jnp
-
-import brainpy.math as bm
-import taichi as ti
-
-from brainpy._src.dependency_check import import_taichi
-ti = import_taichi(error_if_not_found=False)
-if ti is None:
- import pytest
- pytest.skip('no taichi', allow_module_level=True)
-
-
-@ti.func
-def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
- return weight[0]
-
-
-@ti.func
-def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
- out[index] += weight_val
-
-
-@ti.kernel
-def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
- vector: ti.types.ndarray(ndim=1),
- weight: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- weight_val = get_weight(weight)
- num_rows, num_cols = indices.shape
- ti.loop_config(serialize=True)
- for i in range(num_rows):
- if vector[i]:
- for j in range(num_cols):
- update_output(out, indices[i, j], weight_val)
-
-
-prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
-
-
-def test_taichi_clean_cache():
- s = 1000
- indices = bm.random.randint(0, s, (s, 100))
- vector = bm.random.rand(s) < 0.1
- weight = bm.array([1.0])
-
- out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
-
- out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
-
- print(out)
- bm.clear_buffer_memory()
-
- print('kernels: ', bm.count_taichi_aot_kernels())
-
- bm.clear_taichi_aot_caches()
-
- print('kernels: ', bm.count_taichi_aot_kernels())
-
-# test_taichi_clean_cache()
diff --git a/brainpy/_src/math/op_register/utils.py b/brainpy/_src/math/op_register/utils.py
deleted file mode 100644
index 2a10443db..000000000
--- a/brainpy/_src/math/op_register/utils.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-from functools import partial
-
-import jax.numpy as jnp
-from jax import lax
-from jax.interpreters import batching
-from jax.tree_util import tree_flatten, tree_unflatten
-
-__all__ = [
- 'register_general_batching',
-]
-
-
-def _general_batching_rule(prim, args, axes, **kwargs):
- batch_axes, batch_args, non_batch_args = [], {}, {}
- for ax_i, ax in enumerate(axes):
- if ax is None:
- non_batch_args[f'ax{ax_i}'] = args[ax_i]
- else:
- batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0)
- batch_axes.append(ax_i)
-
- def f(_, x):
- pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
- for i in range(len(axes))])
- return 0, prim.bind(*pars, **kwargs)
-
- _, outs = lax.scan(f, 0, batch_args)
- out_vals, out_tree = tree_flatten(outs)
- out_dim = tree_unflatten(out_tree, (0,) * len(out_vals))
- return outs, out_dim
-
-
-def register_general_batching(prim):
- batching.primitive_batchers[prim] = partial(_general_batching_rule, prim)
-
-
-def _shape_to_layout(shape):
- return tuple(range(len(shape) - 1, -1, -1))
-
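
`register_general_batching` above makes a primitive `vmap`-compatible without a hand-written batching rule by moving the batched axis to the front and `lax.scan`-ning the unbatched primitive over it. A self-contained sketch of that idea on a toy single-input primitive (it simplifies away the batched/non-batched argument bookkeeping that `_general_batching_rule` handles):

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax import lax
from jax.core import Primitive, ShapedArray
from jax.interpreters import batching, mlir

# Toy primitive with impl / abstract-eval / lowering, but no batching rule yet.
norm_p = Primitive('my_norm')
norm_p.def_impl(lambda x: jnp.sqrt(jnp.sum(x * x)))
norm_p.def_abstract_eval(lambda x: ShapedArray((), x.dtype))
mlir.register_lowering(norm_p, mlir.lower_fun(lambda x: jnp.sqrt(jnp.sum(x * x)),
                                              multiple_results=False))


def general_batching(prim, args, axes):
    # Move the batched axis to the front, then scan the unbatched primitive over it.
    x, = args
    ax, = axes
    x = x if ax == 0 else jnp.moveaxis(x, ax, 0)
    _, out = lax.scan(lambda carry, xi: (carry, prim.bind(xi)), 0, x)
    return out, 0  # results are stacked along axis 0


batching.primitive_batchers[norm_p] = partial(general_batching, norm_p)

print(jax.vmap(norm_p.bind)(jnp.ones((3, 4))))  # [2. 2. 2.]
```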
diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py
index 59588d3b9..776da1b5c 100644
--- a/brainpy/_src/math/others.py
+++ b/brainpy/_src/math/others.py
@@ -6,6 +6,7 @@
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map
+import numpy as np
from brainpy import check, tools
from .compat_numpy import fill_diagonal
@@ -100,7 +101,8 @@ def false_f(x):
return (jnp.exp(x) - 1) / x
# return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x)
- return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x)
+ # return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x)
+ return jax.lax.select(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x)
def exprel(x, threshold: float = None):
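
Context for the `exprel` hunk above: the `threshold` branch exists because evaluating `(exp(x) - 1) / x` directly suffers catastrophic cancellation near zero, which the truncated Taylor series `1 + x/2 + x**2/6` avoids. A quick float64 illustration (plain NumPy, independent of the `where`/`select` choice):

```python
import numpy as np

x = 1e-9
direct = (np.exp(x) - 1.0) / x           # ~1.000000082740371: ~8e-8 relative error
taylor = 1.0 + x / 2.0 + x * x / 6.0     # ~1.0000000005: accurate to machine precision
print(direct, taylor)
```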
diff --git a/brainpy/_src/math/pre_syn_post.py b/brainpy/_src/math/pre_syn_post.py
index bc9785692..06976a35b 100644
--- a/brainpy/_src/math/pre_syn_post.py
+++ b/brainpy/_src/math/pre_syn_post.py
@@ -56,7 +56,7 @@ def pre2post_event_sum(events,
for i in range(pre_num):
if events[i]:
         for j in range(indptr[i], indptr[i+1]):
- post_val[post_ids[i]] += values
+ post_val[post_ids[j]] += values
When ``values`` is a vector (with the length of ``len(post_ids)``),
this function is equivalent to
@@ -70,7 +70,7 @@ def pre2post_event_sum(events,
for i in range(pre_num):
if events[i]:
         for j in range(indptr[i], indptr[i+1]):
- post_val[post_ids[i]] += values[j]
+ post_val[post_ids[j]] += values[j]
Parameters
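
The two docstring fixes above change `post_ids[i]` to `post_ids[j]`; a plain NumPy rendering of the corrected pseudocode (a reference sketch, not the actual implementation) makes the indexing explicit:

```python
import numpy as np


def pre2post_event_sum_ref(events, post_ids, indptr, values, post_num):
    # Reference (non-JIT) sketch of the corrected docstring pseudocode.
    post_val = np.zeros(post_num)
    homo = np.ndim(values) == 0          # scalar weight vs. one value per connection
    for i in range(len(events)):         # loop over presynaptic neurons
        if events[i]:                    # only spiking neurons contribute
            for j in range(indptr[i], indptr[i + 1]):
                post_val[post_ids[j]] += values if homo else values[j]
    return post_val
```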
diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py
index 9ae012bc4..3f3a8446d 100644
--- a/brainpy/_src/math/random.py
+++ b/brainpy/_src/math/random.py
@@ -4,13 +4,14 @@
from collections import namedtuple
from functools import partial
from operator import index
-from typing import Optional, Union, Sequence
+from typing import Optional, Union, Sequence, Any
import jax
import numpy as np
from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
from jax._src.array import ArrayImpl
-from jax.experimental.host_callback import call
+from jax._src.core import _canonicalize_dimension, _invalid_shape_error
+from jax._src.typing import Shape
from jax.tree_util import register_pytree_node_class
from brainpy.check import jit_error_checking, jit_error_checking_no_args
@@ -34,7 +35,7 @@
'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal',
'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power',
'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min',
- 'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical',
+ 'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical', 'canonicalize_shape',
# pytorch compatibility
'rand_like', 'randint_like', 'randn_like',
@@ -66,10 +67,9 @@ def _size2shape(size):
def _check_shape(name, shape, *param_shapes):
- shape = core.as_named_shape(shape)
if param_shapes:
- shape_ = lax.broadcast_shapes(shape.positional, *param_shapes)
- if shape.positional != shape_:
+ shape_ = lax.broadcast_shapes(shape, *param_shapes)
+ if shape != shape_:
msg = ("{} parameter shapes must be broadcast-compatible with shape "
"argument, and the result of broadcasting the shapes must equal "
"the shape argument, but got result {} for shape argument {}.")
@@ -438,6 +438,22 @@ def _check_py_seq(seq):
return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
+def canonicalize_shape(shape: Shape, context: str = "") -> tuple[Any, ...]:
+ """Canonicalizes and checks for errors in a user-provided shape value.
+
+ Args:
+    shape: a Python value that represents a shape.
+    context: an optional description included in the error message when the shape is invalid.
+
+ Returns:
+ A tuple of canonical dimension values.
+ """
+ try:
+ return tuple(map(_canonicalize_dimension, shape))
+ except TypeError:
+ pass
+ raise _invalid_shape_error(shape, context)
+
+
@register_pytree_node_class
class RandomState(Variable):
"""RandomState that track the random generator state. """
@@ -1098,7 +1114,7 @@ def weibull_min(self, a, scale=None, size: Optional[Union[int, Sequence[int]]] =
def maxwell(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
key = self.split_key() if key is None else _formalize_key(key)
- shape = core.canonicalize_shape(_size2shape(size)) + (3,)
+ shape = canonicalize_shape(_size2shape(size)) + (3,)
norm_rvs = jr.normal(key=key, shape=shape)
r = jnp.linalg.norm(norm_rvs, axis=-1)
return _return(r)
@@ -1233,9 +1249,9 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona
if size is None:
size = jnp.shape(a)
dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
- r = call(lambda x: np.random.zipf(x, size).astype(dtype),
- a,
- result_shape=jax.ShapeDtypeStruct(size, dtype))
+ r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
+ jax.ShapeDtypeStruct(size, dtype),
+ a)
return _return(r)
def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
@@ -1244,9 +1260,9 @@ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Option
size = jnp.shape(a)
size = _size2shape(size)
dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
- r = call(lambda a: np.random.power(a=a, size=size).astype(dtype),
- a,
- result_shape=jax.ShapeDtypeStruct(size, dtype))
+ r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
+ jax.ShapeDtypeStruct(size, dtype),
+ a)
return _return(r)
def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1260,11 +1276,11 @@ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
size = _size2shape(size)
d = {'dfnum': dfnum, 'dfden': dfden}
dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
- r = call(lambda x: np.random.f(dfnum=x['dfnum'],
- dfden=x['dfden'],
- size=size).astype(dtype),
- d,
- result_shape=jax.ShapeDtypeStruct(size, dtype))
+ r = jax.pure_callback(lambda x: np.random.f(dfnum=x['dfnum'],
+ dfden=x['dfden'],
+ size=size).astype(dtype),
+ jax.ShapeDtypeStruct(size, dtype),
+ d)
return _return(r)
def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1280,12 +1296,12 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc
size = _size2shape(size)
dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
- r = call(lambda d: np.random.hypergeometric(ngood=d['ngood'],
- nbad=d['nbad'],
- nsample=d['nsample'],
- size=size).astype(dtype),
- d,
- result_shape=jax.ShapeDtypeStruct(size, dtype))
+ r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
+ nbad=d['nbad'],
+ nsample=d['nsample'],
+ size=size).astype(dtype),
+ jax.ShapeDtypeStruct(size, dtype),
+ d)
return _return(r)
def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1295,9 +1311,9 @@ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
size = jnp.shape(p)
size = _size2shape(size)
dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
- r = call(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
- p,
- result_shape=jax.ShapeDtypeStruct(size, dtype))
+ r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
+ jax.ShapeDtypeStruct(size, dtype),
+ p)
return _return(r)
def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1312,11 +1328,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[in
size = _size2shape(size)
d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
- r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
- dfden=x['dfden'],
- nonc=x['nonc'],
- size=size).astype(dtype),
- d, result_shape=jax.ShapeDtypeStruct(size, dtype))
+ r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
+ dfden=x['dfden'],
+ nonc=x['nonc'],
+ size=size).astype(dtype),
+ jax.ShapeDtypeStruct(size, dtype),
+ d)
return _return(r)
# PyTorch compatibility #
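
The `zipf`, `power`, `f`, `hypergeometric`, `logseries`, and `noncentral_f` hunks above all perform the same mechanical migration from the removed `jax.experimental.host_callback.call` to `jax.pure_callback`: the result specification moves from the `result_shape=` keyword to the second positional argument, placed before the operands. A standalone sketch of the new call shape (the `np_zipf` helper is illustrative, not BrainPy code):

```python
import jax
import jax.numpy as jnp
import numpy as np

size, dtype = (5,), jnp.int32


def np_zipf(a):
    # Host-side NumPy sampling; pure_callback hands `a` over as a NumPy array.
    return np.random.zipf(a, size).astype(dtype)


a = jnp.float32(2.0)
r = jax.pure_callback(np_zipf,                            # host callback
                      jax.ShapeDtypeStruct(size, dtype),  # result shape/dtype spec
                      a)                                  # operands follow
print(r)
```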
diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py
index 14256cbce..eec5f53c0 100644
--- a/brainpy/_src/math/sparse/__init__.py
+++ b/brainpy/_src/math/sparse/__init__.py
@@ -1,8 +1,7 @@
# from ._coo_mv import *
-# from ._bsr_mv import *
from .csr_mv import *
+from .csr_mm import *
from .utils import *
-from .bsr_mm import *
from .jax_prim import *
diff --git a/brainpy/_src/math/sparse/bsr_mm.py b/brainpy/_src/math/sparse/bsr_mm.py
deleted file mode 100644
index 19800749d..000000000
--- a/brainpy/_src/math/sparse/bsr_mm.py
+++ /dev/null
@@ -1,462 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from functools import partial
-from typing import Tuple
-
-import jax.lax
-import numpy as np
-from jax import numpy as jnp
-from jax.core import Primitive, ShapedArray
-from jax.interpreters import ad, xla
-from jax.lib import xla_client
-
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba
-from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import (compile_cpu_signature_with_numba,
- register_general_batching)
-from brainpy.errors import GPUOperatorNotFound
-
-numba = import_numba(error_if_not_found=False)
-
-__all__ = [
- 'bcsrmm',
-]
-
-
-def get_mask(dense_b, blockshape, blockcount):
- mask = jnp.zeros(blockcount[0] * blockcount[1], dtype=jnp.bool_)
-
- for i in range(blockcount[1]):
- for j in range(blockcount[0]):
- if jnp.abs(dense_b[i * blockshape[1]: (i + 1) * blockshape[1],
- j * blockshape[0]: (j + 1) * blockshape[0]]).sum() != 0:
- mask = mask.at[i * blockcount[0] + j].set(True)
- mask = mask.reshape(blockcount[1], blockcount[0])
- return mask
-
-
-def get_mask_from_ptr_indices(ptr, indices, blockcount):
- mask = jnp.zeros((blockcount[1], blockcount[0]), dtype=jnp.bool_)
- for idx, indice in enumerate(indices):
- row_index = 0
- for ptr_ in ptr[1:]:
- if idx < ptr_:
- break
- row_index += 1
- mask = mask.at[row_index, indice].set(True)
- return mask
-
-
-def get_data(dense_b, mask, blockshape, blockcount, n_blocks):
- data = jnp.zeros(
- shape=(n_blocks * blockshape[1], blockshape[0]),
- dtype=jnp.float32
- )
-
- assignment_count = 0
- for i in range(blockcount[1]):
- for j in range(blockcount[0]):
- if mask[i, j]:
- data = data.at[assignment_count * blockshape[1]: (assignment_count + 1) * blockshape[1],
- :].set(dense_b[i * blockshape[1]: (i + 1) * blockshape[1],
- j * blockshape[0]: (j + 1) * blockshape[0]])
- assignment_count += 1
- return data
-
-
-def get_ptr_indices(mask, blockcount, n_blocks, block_ptr=None):
- nnz = jnp.nonzero(mask)
-
- if block_ptr is None:
- block_ptr = jnp.arange(0, len(nnz[0]))
-
- indices = jnp.argsort(block_ptr)
- _ = jnp.take(block_ptr, indices)
-
- blocks = nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)]
- blocks = jnp.stack([nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)]], axis=-1).astype(
- dtype=jnp.int32
- )
- blocks = jnp.flip(blocks, axis=-1).flatten()
-
- X = blockcount[1]
- Y = blockcount[0]
-
- rows = nnz[0][:]
- cols = nnz[1][:]
-
- block_indices = jnp.zeros(X * Y, dtype=jnp.int32)
- positions = rows * Y + cols
- block_indices = block_indices.at[positions].set(block_ptr + 1)
- block_indices = block_indices.reshape(X, Y).transpose().reshape(X * Y)
-
- block_ptr = block_indices[jnp.nonzero(block_indices)[0]] - 1
-
- X, Y = Y, X
- rows = cols
- nnztt = jnp.nonzero(mask.transpose())
- cols = nnztt[:][1]
-
- rows.astype(jnp.int32)
-
- ptr_b = jnp.zeros((X + 1,), dtype=jnp.int32)
- for row in rows:
- ptr_b = ptr_b.at[row + 1].set(ptr_b[row + 1] + 1)
- ptr_b = ptr_b.cumsum(0).astype(dtype=jnp.int32)
-
- indices_b = jnp.stack([cols, block_ptr], axis=1).astype(dtype=jnp.int32)
-
- return ptr_b, indices_b
-
-
-def get_dense(ptr, indices, data, shape, blockshape):
- mask = get_mask_from_ptr_indices(ptr, indices, blockshape)
- dense_data = jnp.zeros(shape, dtype=jnp.float32)
- mask_count = 0
- for i in range(mask.shape[1]):
- for j in range(mask.shape[0]):
- if mask[i, j]:
- dense_data = dense_data.at[
- i * blockshape[0]: (i + 1) * blockshape[0],
- j * blockshape[1]: (j + 1) * blockshape[1],
- ].set(data[mask_count * blockshape[0]: (mask_count + 1) * blockshape[0], :])
- mask_count += 1
- return dense_data
-
-
-def blocksparse_matmat_multiply(dense_a,
- ptr_b=None,
- indices_b=None,
- data_b=None,
- shape_b=None,
- dense_b=None,
- blockshape=(32, 32),
- device='cpu'):
- if dense_b is not None:
- # m, n, k
- m = dense_a.shape[0]
- k = dense_a.shape[1]
- n = dense_b.shape[1]
-
- # blockcount
- blockcount = (n // blockshape[0], k // blockshape[1])
-
- # mask
- mask = get_mask(dense_b, blockshape, blockcount)
-
- # n_blocks
- n_blocks = mask.sum()
-
- # data_b
- data_b = get_data(dense_b, mask, blockshape, blockcount, n_blocks)
-
- # ptr_b, indices_b
- ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks)
- else:
- # m, n, k
- m = dense_a.shape[0]
- n = shape_b[1]
- k = dense_a.shape[1]
-
- # blockcount
- blockcount = (n // blockshape[0], k // blockshape[1])
-
- mask = get_mask_from_ptr_indices(ptr_b, indices_b, blockcount)
-
- n_blocks = mask.sum()
-
- ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks)
-
- # out
- # out = jnp.zeros((n, m))
-
- # verbose
- print('data_b: ', data_b)
- print('ptr:', ptr_b)
- print('indices:', indices_b)
-
- '''out = blocksparse_matmat_cpu_test(dense_a,
- ptr_b,
- indices_b,
- data_b,
- out,
- m=m,
- n=n,
- k=k,
- block_size_k=blockshape[0],
- block_size_n=blockshape[1])
- return out'''
-
- if device == 'cpu':
- out = bcsrmm(
- dense_a,
- ptr_b,
- indices_b,
- data_b,
- m=m,
- n=n,
- k=k,
- block_size_k=blockshape[0],
- block_size_n=blockshape[1],
- )
- return out
- elif device == 'gpu':
- out = bcsrmm(
- dense_a,
- ptr_b,
- indices_b,
- data_b,
- m=m,
- n=n,
- k=k,
- block_size_k=blockshape[0],
- block_size_n=blockshape[1],
- )
- return out.transpose()
- else:
- raise Exception('Invalid device: ', device)
-
-
-def bcsrmm(
- A_data: jax.Array,
- B_data: jax.Array,
- B_indices: jax.Array,
- B_ptr: jax.Array,
- *,
- shape: Tuple[int, int],
- block_size: Tuple[int, int],
- transpose: bool = False,
- method: str = 'cutlass'
-) -> jax.Array:
- """Perform the matrix multiplication :math:`C = A @ B` with BSR data structure.
-
- Args:
- A_data: The dense matrix :math:`A`.
- B_data: The data at each block of :math:`B`.
- B_indices: The sparse indices of :math:`B`.
- B_ptr: The connection pointer of :math:`B`.
- shape: a tuple of int, indicating the array shape of :math:`B`.
-    block_size: a tuple of int, indicating the block size for partitioning :math:`B`.
-    transpose: boolean. If True, perform :math:`A @ B^T`; otherwise, perform :math:`A @ B`.
-    method: a string denoting the BSR sparse computing method.
-
- Returns:
- The dense array :math:`C`.
- """
- A_data = as_jax(A_data)
- B_data = as_jax(B_data)
- B_indices = as_jax(B_indices)
- B_ptr = as_jax(B_ptr)
- assert A_data.shape[1] == shape[0]
-
- if method == 'cutlass':
- C = _bcsrmm_cutlass_p.bind(A_data,
- B_data,
- B_indices,
- B_ptr,
- m=A_data.shape[0],
- k=shape[0],
- n=shape[1],
- transpose=transpose,
- block_size_k=block_size[0],
- block_size_n=block_size[1])[0]
- return C.T
- else:
- raise ValueError
-
-
-if numba is not None:
- @numba.njit(fastmath=True, parallel=True, nogil=True)
- def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m)
- res_val = outs[0]
- # B_data: (num_block, block_size_k, block_size_n)
- A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins
- block_size_k = block_size_k[()]
- block_size_n = block_size_n[()]
- n_block = n // block_size_n
-
- for ni in numba.prange(n_block):
- C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype)
- start, end = B_inptr[ni], B_inptr[ni + 1]
- ns = ni * block_size_n
- ne = ns + block_size_n
- for i in range(start, end):
- ki = B_indices[i, 0]
- ks = ki * block_size_k
- ke = ki + block_size_k
- bi = B_indices[i, 1]
- C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T)
- res_val[ns: ne] = C_tmp
- return res_val
-
-
- @numba.njit(fastmath=True, parallel=True, nogil=True)
- def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m)
- res_val = outs[0]
- # B_data: (num_block, block_size_n, block_size_k)
- A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins
- block_size_k = block_size_k[()]
- block_size_n = block_size_n[()]
- n_block = n // block_size_n
-
- for ni in numba.prange(n_block):
- C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype)
- start, end = B_inptr[ni], B_inptr[ni + 1]
- ns = ni * block_size_n
- ne = ns + block_size_n
- for i in range(start, end):
- ki = B_indices[i, 0]
- ks = ki * block_size_k
- ke = ki + block_size_k
- bi = B_indices[i, 1]
- C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T)
- res_val[ns: ne] = C_tmp
- return res_val
-
-
-def _bcsrmm_cutlass_abstract(
- A_data, B_data, B_indices, B_ptr, *, m, k, n, block_size_k, block_size_n
-):
- assert block_size_k == 32, 'cutlass based block-sparse mm only support block size (32, 32)'
- assert block_size_n == 32, 'cutlass based block-sparse mm only support block size (32, 32)'
- assert B_indices.shape[0] * block_size_n == B_data.shape[0]
- assert block_size_k == B_data.shape[1]
- assert A_data.shape[0] == m
- assert A_data.shape[1] == k
- assert A_data.dtype == B_data.dtype
- assert n // block_size_n + 1 == B_ptr.shape[0]
- return [ShapedArray(dtype=A_data.dtype, shape=(n, m))]
-
-
-def _bcsrmm_cutlass_cpu_translation(
- c, A_data, B_data, B_indices, B_ptr, *,
- m, k, n, block_size_k, block_size_n
-):
- inputs = (A_data, B_ptr, B_indices, B_data)
- description = dict(m=m,
- n=n,
- k=k,
- block_size_k=block_size_k,
- block_size_n=block_size_n)
- name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba(
- c,
- _bcsrmm_cutlass_imp2,
- abs_eval_fn=_bcsrmm_cutlass_abstract,
- multiple_results=True,
- inputs=inputs,
- description=description
- )
- return xla_client.ops.CustomCallWithLayout(
- c, name,
- operands=inputs,
- operand_shapes_with_layout=in_layouts,
- shape_with_layout=out_layouts,
- )
-
-
-def _bcsrmm_cutlass_gpu_translation(c, A_data, B_data, B_indices, B_ptr, *, m, k, n, block_size_k, block_size_n):
- gpu_ops = import_brainpylib_gpu_ops()
- if gpu_ops is None:
- raise GPUOperatorNotFound(_bcsrmm_cutlass_p.name)
-
- matrix_info = c.get_shape(A_data)
- dtype = matrix_info.element_type()
-
- opaque = gpu_ops.build_blocksparse_format_descriptor(m,
- n,
- k,
- block_size_k,
- block_size_n)
-
- fn = b'gpu_blocksparse_matmat'
-
- return xla_client.ops.CustomCallWithLayout(
- c,
- fn,
- operands=(A_data, B_ptr, B_indices, B_data,),
- operand_shapes_with_layout=(c.get_shape(A_data),
- c.get_shape(B_ptr),
- c.get_shape(B_indices),
- c.get_shape(B_data),),
- shape_with_layout=xla_client.Shape.tuple_shape(
- (xla_client.Shape.array_shape(dtype, (n, m), (1, 0)),)
- ),
- opaque=opaque
- )
-
-
-def _bcsrmm_cutlass_jvp_dense_a(dense_a_dot, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_k,
- block_size_n):
- return bcsrmm(dense_a_dot, B_ptr, B_indices, B_data, m=m, n=n, k=k, block_size_k=block_size_k,
- block_size_n=block_size_n)
-
-
-def _bcsrmm_cutlass_jvp_data_b(data_b_dot, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_k,
- block_size_n):
- return bcsrmm(A_data, B_ptr, B_indices, data_b_dot, m=m, n=n, k=k, block_size_k=block_size_k,
- block_size_n=block_size_n)
-
-
-def _bcsrmm_cutlass_jvp_transpose():
- # TODO: implement
- pass
-
-
-_bcsrmm_cutlass_p = Primitive('bcsrmm_cutlass_pim')
-_bcsrmm_cutlass_p.multiple_results = True
-_bcsrmm_cutlass_p.def_abstract_eval(_bcsrmm_cutlass_abstract)
-_bcsrmm_cutlass_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_p))
-# xla.backend_specific_translations['cpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_cpu_translation
-# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_gpu_translation
-ad.primitive_jvps[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose
-ad.primitive_transposes[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose
-register_general_batching(bcsrmm)
-
-
-def _blocksparse_matmat_back_abstract(
- A_data, B_data, blocks, *, m, n, k, transpose, block_size_k, block_size_n, blocks_len
-):
- shape = (n, k)
- dtype = A_data.dtype
- out = ShapedArray(dtype=dtype, shape=shape)
- return [out]
-
-
-def _blocksparse_matmat_back_gpu_translation(
- c, A_data, B_data, blocks, *, m, n, k, transpose, block_size_k, block_size_n, blocks_len
-):
- gpu_ops = import_brainpylib_gpu_ops()
- if gpu_ops is None:
- raise GPUOperatorNotFound(_bcsrmm_cutlass_back_p.name)
- matrix_info = c.get_shape(A_data)
- dtype = matrix_info.element_type()
-
- opaque = gpu_ops.build_blocksparse_back_format_descriptor(m,
- n,
- k,
- block_size_k,
- block_size_n,
- blocks_len)
-
- fn = b'gpu_blocksparse_matmat_back'
-
- return xla_client.ops.CustomCallWithLayout(
- c,
- fn,
- operands=(A_data, B_data, blocks,),
- operand_shape_with_layout=(c.get_shape(A_data),
- c.get_shape(B_data),
- c.get_shape(blocks),),
- shape_with_layout=xla_client.Shape.tuple_shape(
- (xla_client.Shape.array_shape(dtype, (k, n), (1, 0)),)
- ),
- opaque=opaque
- )
-
-
-_bcsrmm_cutlass_back_p = Primitive('bcsrmm_cutlass_back_prim')
-_bcsrmm_cutlass_back_p.multiple_results = True
-_bcsrmm_cutlass_back_p.def_abstract_eval(_blocksparse_matmat_back_abstract)
-_bcsrmm_cutlass_back_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_back_p))
-# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_back_p] = _blocksparse_matmat_back_gpu_translation
-register_general_batching(_bcsrmm_cutlass_back_p)
diff --git a/brainpy/_src/math/sparse/bsr_mv.py b/brainpy/_src/math/sparse/bsr_mv.py
deleted file mode 100644
index 7dc0b683d..000000000
--- a/brainpy/_src/math/sparse/bsr_mv.py
+++ /dev/null
@@ -1,210 +0,0 @@
-from functools import partial
-from typing import Union, Tuple
-
-import numba
-import numpy as np
-from jax import numpy as jnp
-from jax.core import ShapedArray, Primitive
-from jax.interpreters import ad, xla
-from jax.lib import xla_client
-
-from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import (compile_cpu_signature_with_numba,
- register_general_batching)
-from brainpy._src.math.sparse.utils import csr_to_coo
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops
-from brainpy.errors import GPUOperatorNotFound
-
-__all__ = [
- 'cusparse_bcsr_matvec'
-]
-
-
-@numba.njit(fastmath=True, parallel=True, nogil=True)
-def _cusparse_bcsr_matvec_bsr_matvec_numba_imp(outs, ins):
- data, indices, indptr, vector, blocksize, shape, nnzb, transpose = ins
- blocksize = blocksize[()]
- outs.fill(0)
- for i in range(shape[0]):
- tmp = np.zeros(blocksize, dtype=data.dtype)
- for j in range(indptr[i], indptr[i + 1]):
- start = indices[j] * blocksize
- end = start + blocksize
- tmp += data[start: end] @ vector[start: end]
- outs[i * blocksize: (i + 1) * blocksize] = tmp
-
-
-# @numba.njit(fastmath=True, parallel=True, nogil=True)
-# def _cusparse_bcsr_matvec_bsr_matvec_numba_imp(outs, ins):
-# data, indices, indptr, vector, blocksize , shape,nnzb,transpose = ins
-# blocksize = blocksize[()]
-# outs.fill(0)
-
-# cnt=0
-# for i in range(0,shape[0]):
-# outs.fill(0.0)
-# tmp=[0.0]*blocksize
-# for j in range(indptr[i], indptr[i + 1]):
-# for p in range(0,blocksize):
-# for q in range(0,blocksize):
-# tmp[p] += vector[indices[j]*blocksize+q]*data[j*blocksize+p][q]
-# for j in range(0,blocksize):
-# outs[cnt] = tmp[j]
-# cnt+=1
-
-
-def _cusprase_bcsr_matvec_values(values, indices, indptr, vector, *, blocksize, nnzb, shape, transpose):
- return cusparse_bcsr_matvec(values, indices, indptr, vector, blocksize, nnzb=nnzb, shape=shape, transpose=transpose)
-
-
-def cusparse_bcsr_matvec(
- data: Union[float, jnp.ndarray],
- indices: jnp.ndarray,
- indptr: jnp.ndarray,
- vector: jnp.ndarray,
- *,
- blocksize: int,
- nnzb: int,
- shape: Tuple[int, int],
- method: str = 'vector',
- transpose: bool = False
-) -> jnp.ndarray:
- data = as_jax(data)
- indices = as_jax(indices)
- indptr = as_jax(indptr)
- vector = as_jax(vector)
- if method not in ['scalar', 'vector', 'adaptive']:
- raise ValueError('Only support methods: scalar, vector, and adaptive. '
- f'But we got {method}.')
-
- data = jnp.atleast_1d(data)
- if not isinstance(data, jnp.ndarray):
- raise TypeError(f'data must a ndarray. But we got {type(data)}')
- if data.dtype not in [jnp.float32, jnp.float64]:
- raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.')
- if data.dtype != vector.dtype:
- raise TypeError('The types of data and vector should be the same. '
- f'But we got {data.dtype} != {vector.dtype}.')
- # assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1
-
- return cusparse_bcsr_matvec_vector_p.bind(data, indices, indptr, vector, blocksize=blocksize, shape=shape, nnzb=nnzb,
- transpose=transpose)
-
-
-def _cusparse_bcsr_matvec_vector_cpu_translation(c, data, indices, indptr, vector, *, blocksize, shape, nnzb,
- transpose):
- inputs = (data, indices, indptr, vector)
- print(c.get_shape(data))
- description = dict(blocksize=blocksize, shape=shape, nnzb=nnzb, transpose=transpose, )
- if transpose:
- skip = 1
- else:
- name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba(
- c,
- _cusparse_bcsr_matvec_bsr_matvec_numba_imp,
- abs_eval_fn=_cusparse_bcsr_matvec_abstract,
- multiple_results=False,
- inputs=inputs,
- description=description
- )
- return xla_client.ops.CustomCallWithLayout(
- c, name,
- operands=inputs,
- operand_shapes_with_layout=in_layouts,
- shape_with_layout=out_layouts,
- )
-
-
-def _cusparse_bcsr_matvec_vector_gpu_translation(c, data, indices, indptr, vector, *, blocksize, shape, nnzb):
- gpu_ops = import_brainpylib_gpu_ops()
- if gpu_ops is None:
- raise GPUOperatorNotFound(cusparse_bcsr_matvec_vector_p.name)
-
- data_shape = c.get_shape(data)
- if data_shape.element_type() == np.float32:
- type_name = b'float'
- elif data_shape.element_type() == np.double:
- type_name = b'double'
- else:
- raise ValueError('data_type not support(except float/double)')
-    # (it might not be this one)
-
- opaque = gpu_ops.build_bcsrcusparsespmv_descriptor(shape[0], shape[1], blocksize, nnzb)
- return xla_client.ops.CustomCallWithLayout(
- c,
- b'gpu_bcsr_cusparse_spmv_' + type_name,
- operands=(data, indices, indptr, vector),
- operand_shapes_with_layout=(c.get_shape(data),
- c.get_shape(indices),
- c.get_shape(indptr),
- c.get_shape(vector),
- ),
- shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0] * blocksize,), (0,)),
- opaque=opaque,
- )
-
-
-# def _bcsr_matvec_abstract(*args, **kwargs):
-# data = args[0]
-# assert len(kwargs) == 1
-# shape = kwargs['shape']
-# return ShapedArray(dtype=data.dtype, shape=(shape[0],))
-
-# bcsr_matvec_vector_p = register_op_with_numba(
-# 'bcsr_matvec_vector',
-# cpu_func=None,
-# out_shapes=_bcsr_matvec_abstract,
-# gpu_func_translation=_bcsr_matvec_vector_gpu_translation,
-# )
-
-
-# def _batch_bcsr_matvec_abstract(
-# values, indices, indptr, vector,block_size, *, shape, transpose=False
-# ):
-# return ShapedArray(dtype=values.dtype, shape=(batch_size, shape[1] if transpose else shape[0]))
-
-def _cusparse_bcsr_matvec_abstract(data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose=False):
- return ShapedArray(dtype=data.dtype, shape=(shape[0] * blocksize,))
-
-
-def _cusparse_bcsr_matvec_jvp_values(data_dot, data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose):
- return cusparse_bcsr_matvec(data_dot, indices, indptr, vector, blocksize=blocksize, nnzb=nnzb, shape=shape,
- transpose=transpose)
-
-
-def _cusparse_bcsr_transpose(ct, data, indices, indptr, vector, *, blocksize, shape, transpose):
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
- raise ValueError("Cannot transpose with respect to sparse indices.")
- if ad.is_undefined_primal(vector):
- ct_events = cusparse_bcsr_matvec(data, indices, indptr, ct, shape=shape, transpose=not transpose)
- return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_events)
- else:
- if type(ct) is ad.Zero:
- ct_values = ad.Zero(data)
- else:
- row, col = csr_to_coo(indices, indptr)
- cnt = 0
- ct_values = []
- for i in row:
- for j in col:
- for p in range(0, blocksize):
- cntq = 0
- for q in range(0, blocksize):
- if transpose:
- ct_values[cnt][cntq] = vector[i * blocksize + p] * ct[j * blocksize + q]
- else:
- ct_values[cnt][cntq] = vector[j * blocksize + q] * ct[i * blocksize + p]
- cntq += 1
- cnt += 1
- return ct_values, indices, indptr, vector
-
-
-cusparse_bcsr_matvec_vector_p = Primitive('cusparse_block_spmv')
-cusparse_bcsr_matvec_vector_p.def_abstract_eval(_cusparse_bcsr_matvec_abstract)
-cusparse_bcsr_matvec_vector_p.def_impl(partial(xla.apply_primitive, cusparse_bcsr_matvec_vector_p))
-# xla.backend_specific_translations['gpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_gpu_translation
-# xla.backend_specific_translations['cpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_cpu_translation
-ad.defjvp(cusparse_bcsr_matvec_vector_p, _cusparse_bcsr_matvec_jvp_values)
-ad.primitive_transposes[cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_transpose
-register_general_batching(cusparse_bcsr_matvec_vector_p)
-# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule
diff --git a/brainpy/_src/math/sparse/coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py
index 2885d9463..c9a46ff6d 100644
--- a/brainpy/_src/math/sparse/coo_mv.py
+++ b/brainpy/_src/math/sparse/coo_mv.py
@@ -1,18 +1,14 @@
# -*- coding: utf-8 -*-
-import warnings
-from functools import partial
from typing import Union, Tuple
-import numpy as np
-from jax import core, numpy as jnp, dtypes, default_backend
-from jax.interpreters import ad, mlir
-from jaxlib import gpu_sparse
+from jax import numpy as jnp
-from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
-from brainpy._src.math.op_register import register_general_batching
+from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
+
+bti = import_braintaichi(error_if_not_found=False)
__all__ = [
'coomv',
@@ -65,136 +61,17 @@ def coomv(
An array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
-
- data = jnp.atleast_1d(as_jax(data))
- row = as_jax(row)
- col = as_jax(col)
- vector = as_jax(vector)
-
- if method == 'cusparse':
- if default_backend() != 'cpu':
- if data.shape[0] == 1:
- data = jnp.ones(row.shape, dtype=data.dtype) * data
- if row.dtype in [jnp.uint32, jnp.uint64]:
- row = jnp.asarray(row, dtype=dtypes.canonicalize_dtype(jnp.int64))
- if col.dtype in [jnp.uint32, jnp.uint64]:
- col = jnp.asarray(col, dtype=dtypes.canonicalize_dtype(jnp.int64))
- return _coomv_cusparse_p.bind(data,
- row,
- col,
- vector,
- shape=shape,
- rows_sorted=rows_sorted,
- cols_sorted=cols_sorted,
- transpose=transpose)
-
- else:
- raise ValueError
-
-
-# --------------------------------------------------------------------
-# cusparse_coo_matvec
-
-
-def _coomv_impl(data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose):
- v = jnp.asarray(v)
- if transpose:
- row, col = col, row
- out_shape = shape[1] if transpose else shape[0]
- dv = data * v[col]
- return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)
-
-
-def _coomv_abstract_eval(data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose):
- assert data.shape == row.shape == col.shape
- assert data.dtype == v.dtype
- assert row.dtype == col.dtype
- assert len(shape) == 2
- assert v.ndim == 1
- assert v.shape[0] == (shape[0] if transpose else shape[1])
- out_shape = shape[1] if transpose else shape[0]
- return core.ShapedArray((out_shape,), data.dtype)
-
-
-_coo_matvec_lowering = mlir.lower_fun(_coomv_impl, multiple_results=False)
-
-
-def _coomv_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *,
- shape, rows_sorted, cols_sorted, transpose):
- data_aval, row_aval, _, x_aval = ctx.avals_in
- dtype = data_aval.dtype
- if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
- warnings.warn(f"cusparse_coo_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
- "Falling back to default implementation.", UserWarning)
- return _coo_matvec_lowering(ctx, data, row, col, v,
- shape=shape,
- rows_sorted=rows_sorted,
- cols_sorted=cols_sorted,
- transpose=transpose)
-
- if rows_sorted:
- shape = shape
- elif cols_sorted:
- row, col = col, row
- transpose = not transpose
- shape = shape[::-1]
- else:
- warnings.warn("cusparse_coo_matvec GPU lowering requires matrices with sorted rows or sorted cols. "
- "To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
- "back to the default implementation.", UserWarning)
- return _coo_matvec_lowering(ctx, data, row, col, v,
- shape=shape,
- rows_sorted=rows_sorted,
- cols_sorted=cols_sorted,
- transpose=transpose)
-
- return [coo_matvec_mhlo(data, row, col, v,
- shape=shape,
- transpose=transpose,
- index_dtype=row_aval.dtype,
- data_dtype=dtype,
- x_dtype=x_aval.dtype)]
-
-
-def _coomv_jvp_mat(data_dot, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose):
- return _coomv_cusparse_p.bind(data_dot, row, col, v,
- shape=shape,
- rows_sorted=rows_sorted,
- cols_sorted=cols_sorted,
- transpose=transpose)
-
-
-def _coomv_jvp_vec(v_dot, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose):
- return _coomv_cusparse_p.bind(data, row, col, v_dot,
- shape=shape,
- rows_sorted=rows_sorted,
- cols_sorted=cols_sorted,
- transpose=transpose)
-
-
-def _coomv_transpose(ct, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose):
- assert not ad.is_undefined_primal(row)
- assert not ad.is_undefined_primal(col)
-
- if ad.is_undefined_primal(v):
- return data, row, col, _coomv_cusparse_p.bind(data, row, col, ct,
- shape=shape,
- rows_sorted=rows_sorted,
- cols_sorted=cols_sorted,
- transpose=not transpose)
- else:
- return ct[row] * v[col], row, col, v
-
-
-_coomv_cusparse_p = core.Primitive('cusparse_coo_matvec')
-_coomv_cusparse_p.def_abstract_eval(_coomv_abstract_eval)
-_coomv_cusparse_p.def_impl(_coomv_impl)
-ad.defjvp(_coomv_cusparse_p, _coomv_jvp_mat, None, None, _coomv_jvp_vec)
-ad.primitive_transposes[_coomv_cusparse_p] = _coomv_transpose
-mlir.register_lowering(_coomv_cusparse_p, _coo_matvec_lowering)
-mlir.register_lowering(_coomv_cusparse_p,
- partial(_coomv_gpu_lowering, gpu_sparse.cuda_coo_matvec),
- platform='cuda')
-register_general_batching(_coomv_cusparse_p)
-
-
+ if bti is None:
+ raise_braintaichi_not_found()
+
+ return bti.coomv(
+ data=data,
+ row=row,
+ col=col,
+ vector=vector,
+ shape=shape,
+ rows_sorted=rows_sorted,
+ cols_sorted=cols_sorted,
+ transpose=transpose,
+ method=method
+ )
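
For reference, the semantics that `coomv` continues to compute can be sketched as below, mirroring the removed `_coomv_impl` fallback; the braintaichi backend only supplies optimized kernels for the same operation:

import jax.numpy as jnp

def coomv_reference(data, row, col, v, *, shape, transpose=False):
    # out[i] = sum over stored entries (i, j) of data * v[j]
    if transpose:
        row, col = col, row
    out_shape = shape[1] if transpose else shape[0]
    dv = data * v[col]                                      # scale each stored entry by its column of v
    return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)   # scatter-add into the output rows
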
diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py
new file mode 100644
index 000000000..4d5b0d6cd
--- /dev/null
+++ b/brainpy/_src/math/sparse/csr_mm.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+
+
+from typing import Union, Tuple
+
+from jax import numpy as jnp
+
+from brainpy._src.math.ndarray import Array
+from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
+
+bti = import_braintaichi(error_if_not_found=False)
+
+__all__ = [
+ 'csrmm',
+]
+
+
+def csrmm(
+ data: Union[float, jnp.ndarray, Array],
+ indices: Union[jnp.ndarray, Array],
+ indptr: Union[jnp.ndarray, Array],
+ matrix: Union[jnp.ndarray, Array],
+ *,
+ shape: Tuple[int, int],
+ transpose: bool = False,
+):
+ """
+ Product of CSR sparse matrix and a dense matrix.
+
+ Args:
+ data : array of shape ``(nse,)``.
+ indices : array of shape ``(nse,)``
+ indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+    matrix : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+ dtype ``data.dtype``
+ shape : length-2 tuple representing the matrix shape
+ transpose : boolean specifying whether to transpose the sparse matrix
+ before computing.
+
+ Returns:
+ C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+ representing the matrix-matrix product.
+ """
+ if bti is None:
+ raise_braintaichi_not_found()
+
+ return bti.csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
\ No newline at end of file
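
A hypothetical usage sketch for the new `csrmm` wrapper (it assumes braintaichi is installed and that `csrmm` is re-exported as `bm.sparse.csrmm`, in the same way `csrmv` is); it checks the sparse product against a dense reference built with scipy:

import numpy as np
import scipy.sparse as sp
import brainpy.math as bm

dense = np.random.rand(4, 6) * (np.random.rand(4, 6) < 0.3)
A = sp.csr_matrix(dense)                       # sparse operand with shape (4, 6)
B = np.random.rand(6, 5)                       # dense operand: (shape[1], cols)

C = bm.sparse.csrmm(A.data, A.indices, A.indptr, B, shape=A.shape)
assert np.allclose(np.asarray(C), dense @ B)   # result: (shape[0], cols)
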
diff --git a/brainpy/_src/math/sparse/csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py
index 6eaf6b791..c39744bb4 100644
--- a/brainpy/_src/math/sparse/csr_mv.py
+++ b/brainpy/_src/math/sparse/csr_mv.py
@@ -3,20 +3,12 @@
from typing import Union, Tuple
-import jax
from jax import numpy as jnp
-from jax.experimental.sparse import csr
-from jax.interpreters import ad
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-from brainpy._src.math.interoperability import as_jax
+from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
from brainpy._src.math.ndarray import Array
-from brainpy._src.math.op_register import (register_general_batching, XLACustomOp)
-from brainpy._src.math.sparse.utils import csr_to_coo
-from brainpy.errors import PackageMissingError
-ti = import_taichi(error_if_not_found=False)
+bti = import_braintaichi(error_if_not_found=False)
__all__ = [
'csrmv',
@@ -69,257 +61,8 @@ def csrmv(
The array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
+ if bti is None:
+ raise_braintaichi_not_found()
- data = jnp.atleast_1d(as_jax(data))
- indices = as_jax(indices)
- indptr = as_jax(indptr)
- vector = as_jax(vector)
+ return bti.csrmv(data, indices, indptr, vector, shape=shape, transpose=transpose)
- if vector.dtype == jnp.bool_:
- vector = as_jax(vector, dtype=data.dtype)
-
- if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]:
- raise TypeError('Only support float16, float32 or float64 type. '
- f'But we got {data.dtype}.')
- if data.dtype != vector.dtype:
- raise TypeError('The types of data and vector should be the same. '
- f'But we got {data.dtype} != {vector.dtype}.')
- assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1
- if not jnp.issubdtype(indices.dtype, jnp.integer):
- raise ValueError('indices should be a 1D vector with integer type.')
- if not jnp.issubdtype(indptr.dtype, jnp.integer):
- raise ValueError('indptr should be a 1D vector with integer type.')
-
- # if the shape of indices is (0,), then we return a zero vector
- if indices.shape[0] == 0:
- return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)
-
- return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0]
-
-
-def raw_csrmv_taichi(
- data: Union[float, jnp.ndarray, Array],
- indices: Union[jnp.ndarray, Array],
- indptr: Union[jnp.ndarray, Array],
- vector: Union[jnp.ndarray, Array],
- *,
- shape: Tuple[int, int],
- transpose: bool = False,
-):
- if ti is None:
- raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
-
- out_shape = shape[1] if transpose else shape[0]
- if data.shape[0] != 1:
- if bm.get_platform() == 'gpu':
- return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)]
- else:
- if transpose:
- prim = _csr_matvec_transpose_heter_p
- else:
- prim = _csr_matvec_heter_p
- else:
- if transpose:
- prim = _csr_matvec_transpose_homo_p
- else:
- prim = _csr_matvec_homo_p
-
- return prim(data,
- indices,
- indptr,
- vector,
- outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
- transpose=transpose,
- shape=shape)
-
-
-if ti is not None:
-
- # -------------
- # CPU operators
- # -------------
- @ti.kernel
- def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
- col_indices: ti.types.ndarray(ndim=1),
- row_ptr: ti.types.ndarray(ndim=1),
- vector: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- ti.loop_config(serialize=True)
- for row_i in range(row_ptr.shape[0] - 1):
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- out[col_indices[j]] += value * vector[row_i]
-
-
- @ti.kernel
- def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
- col_indices: ti.types.ndarray(ndim=1),
- row_ptr: ti.types.ndarray(ndim=1),
- vector: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- ti.loop_config(serialize=True)
- for row_i in range(row_ptr.shape[0] - 1):
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- out[col_indices[j]] += vector[row_i] * values[j]
-
-
- @ti.kernel
- def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1),
- col_indices: ti.types.ndarray(ndim=1),
- row_ptr: ti.types.ndarray(ndim=1),
- vector: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- # ti.loop_config(serialize=True)
- for row_i in range(row_ptr.shape[0] - 1):
- r = 0.
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- r += vector[col_indices[j]]
- out[row_i] = r * value
-
-
- @ti.kernel
- def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1),
- col_indices: ti.types.ndarray(ndim=1),
- row_ptr: ti.types.ndarray(ndim=1),
- vector: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- # ti.loop_config(serialize=True)
- for row_i in range(row_ptr.shape[0] - 1):
- r = 0.
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
- r += values[j] * vector[col_indices[j]]
- out[row_i] = r
-
-
- # -------------
- # GPU operators
- # -------------
-
- @ti.kernel
- def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
- col_indices: ti.types.ndarray(ndim=1),
- row_ptr: ti.types.ndarray(ndim=1),
- vector: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- for i in range((row_ptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- j = row_ptr[row_i] + index
- end_index = row_ptr[row_i + 1]
- while j < end_index:
- out[col_indices[j]] += value * vector[row_i]
- j += 32
-
-
- @ti.kernel
- def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
- col_indices: ti.types.ndarray(ndim=1),
- row_ptr: ti.types.ndarray(ndim=1),
- vector: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- value = values[0]
- for i in range((row_ptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- r = 0.
- j = row_ptr[row_i] + index
- end_index = row_ptr[row_i + 1]
- while j < end_index:
- r += vector[col_indices[j]]
- j += 32
- out[row_i] += value * r
-
-
- @ti.kernel
- def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
- col_indices: ti.types.ndarray(ndim=1),
- row_ptr: ti.types.ndarray(ndim=1),
- vector: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- for i in range((row_ptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- j = row_ptr[row_i] + index
- end_index = row_ptr[row_i + 1]
- while j < end_index:
- out[col_indices[j]] += values[j] * vector[row_i]
- j += 32
-
-
- @ti.kernel
- def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
- col_indices: ti.types.ndarray(ndim=1),
- row_ptr: ti.types.ndarray(ndim=1),
- vector: ti.types.ndarray(ndim=1),
- out: ti.types.ndarray(ndim=1)):
- for i in range((row_ptr.shape[0] - 1) * 32):
- row_i = i >> 5
- index = i & 31
- r = 0.
- j = row_ptr[row_i] + index
- end_index = row_ptr[row_i + 1]
- while j < end_index:
- r += values[j] * vector[col_indices[j]]
- j += 32
- out[row_i] += r # TODO: warp-level primitive
-
-
- def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
- return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose)
-
-
- def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
- return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose)
-
-
- def _sparse_csr_matvec_transpose(
- ct, data, indices, indptr, vector, *, outs, transpose, shape,
- ):
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
- raise ValueError("Cannot transpose with respect to sparse indices.")
- if ad.is_undefined_primal(vector):
- ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
- return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector)
-
- else:
- if type(ct[0]) is ad.Zero:
- ct_data = ad.Zero(data)
- else:
- if data.aval.shape[0] == 1: # scalar
- ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0]
- ct_data = jnp.inner(ct[0], ct_data)
- else:
- row, col = csr_to_coo(indices, indptr)
- ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row]
-
- return ct_data, indices, indptr, vector
-
-
- def _define_op(cpu_kernel, gpu_kernel):
- prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
- prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector)
- prim.def_transpose_rule(_sparse_csr_matvec_transpose)
- return prim
-
-
- # transpose homo
- _csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu,
- gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu)
-
- # no transpose homo
- _csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu,
- gpu_kernel=_sparse_csr_matvec_homo_gpu)
-
- # transpose heter
- _csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu,
- gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu)
-
- # no transpose heter
- _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu,
- gpu_kernel=_sparse_csr_matvec_heter_gpu)
-
- # heter cusparse
- _csr_matvec_cusparse_p = csr.csr_matvec_p
- register_general_batching(_csr_matvec_cusparse_p)
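
The removed Taichi kernels above all compute the same CSR mat-vec; as a reference for what `csrmv` now delegates to braintaichi, the non-transposed heterogeneous case reduces to the following (illustrative sketch only):

import numpy as np

def csrmv_reference(values, col_indices, row_ptr, vector):
    # One output entry per row: dot the row's stored values with the
    # gathered vector entries (heterogeneous, non-transposed case).
    out = np.zeros(row_ptr.shape[0] - 1, dtype=values.dtype)
    for row_i in range(row_ptr.shape[0] - 1):
        lo, hi = row_ptr[row_i], row_ptr[row_i + 1]
        out[row_i] = values[lo:hi] @ vector[col_indices[lo:hi]]
    return out
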
diff --git a/brainpy/_src/math/sparse/tests/csr_matvec_VS_cusparse_csr_matvec.py b/brainpy/_src/math/sparse/tests/csr_matvec_VS_cusparse_csr_matvec.py
deleted file mode 100644
index 8b4afe21e..000000000
--- a/brainpy/_src/math/sparse/tests/csr_matvec_VS_cusparse_csr_matvec.py
+++ /dev/null
@@ -1,668 +0,0 @@
-
-# -*- coding: utf-8 -*-
-
-import time
-
-import brainpy as bp
-import brainpy.math as bm
-import numpy as np
-
-from brainpy._src.math.sparse import cusparse_bcsr_matvec
-# from brainpy._src.math.sparse import cusparse_csr_matvec
-from brainpy._src.math.sparse import csrmv
-from scipy.sparse import csr_matrix
-
-def compare(platform='cpu'):
- """
-
- CPU
- ---
-
- shape = (1000, 1000)
- cuSPARSE 0.02663278579711914 s
- brainpylib 0.028490781784057617 s
-
- shape = (1000, 10000)
- cuSPARSE 0.06195855140686035 s
- brainpylib 0.04008936882019043 s
-
- shape = (10000, 1000)
- cuSPARSE 0.04706525802612305 s
- brainpylib 0.04366803169250488 s
-
- shape = (10000, 10000)
- cuSPARSE 0.1891341209411621 s
- brainpylib 0.177717924118042 s
-
- shape = (100000, 10000)
- cuSPARSE 1.3123579025268555 s
- brainpylib 1.3357517719268799 s
-
- shape = (100000, 100000)
- cuSPARSE 13.544525384902954 s
- brainpylib 14.612009048461914 s
-
-
- GPU
- ---
- shape = (1000, 1000)
- cuSPARSE 0.04015922546386719 s
- brainpylib 0.024152517318725586 s
-
- shape = (1000, 10000)
- cuSPARSE 0.04857826232910156 s
- brainpylib 0.15707015991210938 s
-
- shape = (10000, 1000)
- cuSPARSE 0.04973483085632324 s
- brainpylib 0.14293313026428223 s
-
- shape = (10000, 10000)
- cuSPARSE 0.17399168014526367 s
- brainpylib 0.17151856422424316 s
-
- shape = (100000, 10000)
- cuSPARSE 0.5249958038330078 s
- brainpylib 0.3427560329437256 s
-
- shape = (50000, 50000)
- cuSPARSE 1.4121572971343994 s
- brainpylib 0.9002335071563721 s
-
- shape = (100000, 50000)
- cuSPARSE 2.697688341140747 s
- brainpylib 1.6211459636688232 s
- """
-
-
- bm.set_platform(platform)
-
- for shape in [
- (1000, 1000),
- (1000, 10000),
- (10000, 1000),
- (10000, 10000),
- (100000, 10000),
- (50000, 50000),
- (100000, 50000),
- ]:
- print(f'shape = {shape}')
-
- rng = bm.random.RandomState(123)
- conn = bp.conn.FixedProb(0.1)(*shape)
- indices, indptr = conn.require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- data = rng.random(indices.shape).value
- vector = rng.random(shape[1]).value
-
- r1 = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, method='cusparse')
- r1.block_until_ready()
- r2 = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, method='vector')
- r2.block_until_ready()
-
- t0 = time.time()
- for _ in range(100):
- r1 = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, method='cusparse')
- r1.block_until_ready()
- print(f'cuSPARSE {time.time() - t0} s')
-
- t0 = time.time()
- for _ in range(100):
- r1 = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, method='vector')
- r1.block_until_ready()
- print(f'brainpylib {time.time() - t0} s')
- print()
-
-
-
-def compare2(platform='cpu'):
- """
-
- CPU
- ---
-
- shape = (1000, 1000)
- cuSPARSE 0.02663278579711914 s
- brainpylib 0.028490781784057617 s
-
- shape = (1000, 10000)
- cuSPARSE 0.06195855140686035 s
- brainpylib 0.04008936882019043 s
-
- shape = (10000, 1000)
- cuSPARSE 0.04706525802612305 s
- brainpylib 0.04366803169250488 s
-
- shape = (10000, 10000)
- cuSPARSE 0.1891341209411621 s
- brainpylib 0.177717924118042 s
-
- shape = (100000, 10000)
- cuSPARSE 1.3123579025268555 s
- brainpylib 1.3357517719268799 s
-
- shape = (100000, 100000)
- cuSPARSE 13.544525384902954 s
- brainpylib 14.612009048461914 s
-
-
- GPU
- ---
- shape = (1000, 1000)
- cuSPARSE 0.04015922546386719 s
- brainpylib 0.024152517318725586 s
-
- shape = (1000, 10000)
- cuSPARSE 0.04857826232910156 s
- brainpylib 0.15707015991210938 s
-
- shape = (10000, 1000)
- cuSPARSE 0.04973483085632324 s
- brainpylib 0.14293313026428223 s
-
- shape = (10000, 10000)
- cuSPARSE 0.17399168014526367 s
- brainpylib 0.17151856422424316 s
-
- shape = (100000, 10000)
- cuSPARSE 0.5249958038330078 s
- brainpylib 0.3427560329437256 s
-
- shape = (50000, 50000)
- cuSPARSE 1.4121572971343994 s
- brainpylib 0.9002335071563721 s
-
- shape = (100000, 50000)
- cuSPARSE 2.697688341140747 s
- brainpylib 1.6211459636688232 s
- """
-
- bm.set_platform(platform)
- p = 0.1
-
- for shape in [
- (1000, 1000),
- (1000, 10000),
- (10000, 1000),
- (10000, 10000),
- (100000, 10000),
- (50000, 50000),
- (100000, 50000),
- ]:
- print(f'shape = {shape}')
-
- rng = bm.random.RandomState()
- conn = bp.conn.FixedProb(p)(*shape)
- indices, indptr = conn.require('pre2post')
- data = rng.random(indices.shape)
- vector = rng.random(shape[1])
-
-
-
-
- bs_bsr = 16
- conn = bp.conn.FixedProb(p)(shape[0] // bs_bsr , shape[1] // bs_bsr)
- indices_bsr, indptr_bsr = conn.require('pre2post')
- data_bsr = rng.rand(len(indices_bsr)*bs_bsr, bs_bsr )
- shape_bsr = (shape[0] // bs_bsr, shape[1] // bs_bsr)
-
- # Mcsr = csr_matrix((data, indices, indptr), shape=shape)
- # Mbsr = Mcsr.tobsr(blocksize=(8,8))
- # bs_bsr = 8
- # indices_bsr = Mbsr.indices
- # indptr_bsr = Mbsr.indptr
- # data_bsr_2 = Mbsr.data
- # data_bsr = list(np.array(data_bsr_2).flatten())
- # indices_bsr = bm.as_jax(indices_bsr)
- # indptr_bsr = bm.as_jax(indptr_bsr)
- # data_bsr = bm.as_jax(data_bsr)
- # shape_bsr = (shape[0]//bs_bsr,shape[1]//bs_bsr)
-
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
-
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- r2 = csrmv(data, indices, indptr, vector, shape=shape)
- r2.block_until_ready()
-
- # print(r1[980:1000])
- # print(r2[980:1000])
- # print(r3[900:1000])
- # print(len(indptr_bsr))
- # print(shape_bsr)
-
- t0 = time.time()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr)
- r3.block_until_ready()
- print(f'bsrSPARSE {time.time() - t0} s')
-
- # t0 = time.time()
- # for _ in range(100):
- # r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape)
- # r3.block_until_ready()
- # print(f'bsrSPARSE {time.time() - t0} s')
-
-
- # t0 = time.time()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # print(f'cuSPARSE {time.time() - t0} s')
- # t0 = time.time()
- # for _ in range(100):
- # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape)
- # r1.block_until_ready()
- # print(f'cuSPARSE {time.time() - t0} s')
-
- t0 = time.time()
- for _ in range(100):
- r1 = csrmv(data, indices, indptr, vector, shape=shape)
- r1.block_until_ready()
- print(f'brainpylib {time.time() - t0} s')
- print()
-
- bm.clear_buffer_memory()
-
-
-if __name__ == '__main__':
- compare('cpu')
- # compare('gpu')
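The deleted benchmark script above (like the two that follow) unrolls its warm-up and timing calls by hand. A compact helper capturing the same pattern — JIT warm-up with `block_until_ready`, then repeated timed calls — could look like this hypothetical sketch; the helper name and iteration counts are illustrative, not part of the original scripts.

import time
import jax

def bench(fn, *args, n_warmup=5, n_repeat=10, **kwargs):
    """Warm up a JIT-compiled function, then time repeated calls (in ms)."""
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args, **kwargs))
    times = []
    for _ in range(n_repeat):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args, **kwargs))
        times.append((time.perf_counter() - t0) * 1000)
    return times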
diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py
deleted file mode 100644
index 1db246212..000000000
--- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py
+++ /dev/null
@@ -1,250 +0,0 @@
-# from jax_taichi import jax_taichi_call
-
-import time
-from functools import partial
-import os
-
-import brainpy as bp
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pandas as pd
-import taichi as ti
-
-bm.set_platform('cpu')
-
-s = [1000, 5000, 10000, 15000, 20000, 25000, 30000]
-p = [0.1, 0.2, 0.3, 0.4, 0.5]
-
-shape = [
- 1000,
- 2500,
- 5000,
- 10000,
- 25000,
- 37500,
- 50000
-]
-
-values_type = [
- 'homo',
- 'heter'
- ]
-events_type = ['float']
-transpose = [
- True,
- False
- ]
-method = 'cusparse'
-
-ITERATION = 100
-if bm.get_platform() == 'cpu':
- ITERATION = 10
-
-print(bm.get_platform())
-
-@partial(jax.jit, static_argnums=(4, 5))
-def csrmv_taichi(weight, indices, indptr, vector, shape, transpose):
- r = 0
- for i in range(ITERATION):
- r += bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0]
- return r
-
-@partial(jax.jit, static_argnums=(4, 5))
-def csrmv(weight, indices, indptr, vector, shape, transpose):
- r = 0
- for i in range(ITERATION):
- r += bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)
- return r
-
-def test_sparse_csrmv(shape, values_type, events_type, transpose):
- rng = bm.random.RandomState(seed=1234)
- indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
- vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
- weight = 1.
-
-
- if events_type == 'float':
- vector = vector.astype(bm.float32)
- if values_type == 'heter':
- heter_data = bm.ones(indices.shape) * weight
- weight = heter_data
-
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-
-PATH = os.path.dirname(os.path.abspath(__file__))
-
-# init dataframe
-df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose',
- 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
- 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
- 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
- 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
-
-### RECTANGULAR MATRIX
-if (bm.get_platform() == 'cpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _values_type in values_type:
- for _events_type in events_type:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
- # append to dataframe
- df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/csrmv_cpu.csv', index=False)
-
-if (bm.get_platform() == 'gpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _values_type in values_type:
- for _events_type in events_type:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
- # append to dataframe
- df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/csrmv_gpu.csv', index=False)
diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py
deleted file mode 100644
index d902c9395..000000000
--- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# from jax_taichi import jax_taichi_call
-
-import time
-from functools import partial
-import os
-
-import brainpy as bp
-import brainpy.math as bm
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pandas as pd
-import taichi as ti
-
-bm.set_platform('cpu')
-
-s = [1000,
- 5000,
- 10000,
- 15000,
- 20000,
- 25000,
- 30000]
-p = [0.1, 0.2, 0.3, 0.4, 0.5]
-
-shape = [
- 1000,
- 2500,
- 5000,
- 10000,
- 25000,
- 37500,
- 50000
-]
-
-values_type = [
- 'homo',
- 'heter'
- ]
-events_type = ['float']
-transpose = [
- True,
- False
- ]
-method = 'cusparse'
-
-ITERATION = 100
-if bm.get_platform() == 'cpu':
- ITERATION = 10
-
-print(bm.get_platform())
-
-def sum_op(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)
- return r.sum()
-
- return func
-
-
-def sum_op2(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)[0]
- return r.sum()
-
- return func
-
-@partial(jax.jit, static_argnums=(4, 5))
-def csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)(
- weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
- return r
-
-@partial(jax.jit, static_argnums=(4, 5))
-def csrmv_grad(weight, indices, indptr, vector, shape, transpose):
- r = 0
- for i in range(ITERATION):
- r += jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
- weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
- return r
-
-def test_sparse_csrmv(shape, values_type, events_type, transpose):
- rng = bm.random.RandomState(seed=1234)
- indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
- vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
- weight = 1.
-
-
- if events_type == 'float':
- vector = vector.astype(bm.float32)
- if values_type == 'heter':
- heter_data = bm.ones(indices.shape) * weight
- weight = heter_data
-
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
-
- time0 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time1 = time.time()
-
- time2 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time3 = time.time()
-
- time4 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time5 = time.time()
-
- time6 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time7 = time.time()
-
- time8 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time9 = time.time()
-
- time10 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time11 = time.time()
-
- time12 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time13 = time.time()
-
- time14 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time15 = time.time()
-
- time16 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time17 = time.time()
-
- time18 = time.time()
- result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time19 = time.time()
-
-
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
-
- time20 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time21 = time.time()
-
- time22 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time23 = time.time()
-
- time24 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time25 = time.time()
-
- time26 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time27 = time.time()
-
- time28 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time29 = time.time()
-
- time30 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time31 = time.time()
-
- time32 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time33 = time.time()
-
- time34 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time35 = time.time()
-
- time36 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time37 = time.time()
-
- time38 = time.time()
- result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose))
- time39 = time.time()
-
- taichi_aot_time1 = (time1 - time0) * 1000
- taichi_aot_time2 = (time3 - time2) * 1000
- taichi_aot_time3 = (time5 - time4) * 1000
- taichi_aot_time4 = (time7 - time6) * 1000
- taichi_aot_time5 = (time9 - time8) * 1000
- taichi_aot_time6 = (time11 - time10) * 1000
- taichi_aot_time7 = (time13 - time12) * 1000
- taichi_aot_time8 = (time15 - time14) * 1000
- taichi_aot_time9 = (time17 - time16) * 1000
- taichi_aot_time10 = (time19 - time18) * 1000
- brainpy_time1 = (time21 - time20) * 1000
- brainpy_time2 = (time23 - time22) * 1000
- brainpy_time3 = (time25 - time24) * 1000
- brainpy_time4 = (time27 - time26) * 1000
- brainpy_time5 = (time29 - time28) * 1000
- brainpy_time6 = (time31 - time30) * 1000
- brainpy_time7 = (time33 - time32) * 1000
- brainpy_time8 = (time35 - time34) * 1000
- brainpy_time9 = (time37 - time36) * 1000
- brainpy_time10 = (time39 - time38) * 1000
- print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
- print('taichi_aot_1: ', taichi_aot_time1, 'ms')
- print('taichi_aot_3: ', taichi_aot_time3, 'ms')
- print('taichi_aot_5: ', taichi_aot_time5, 'ms')
- print('taichi_aot_7: ', taichi_aot_time7, 'ms')
- print('taichi_aot_9: ', taichi_aot_time9, 'ms')
- print('brainpylib_1: ', brainpy_time1, 'ms')
- print('brainpylib_3: ', brainpy_time3, 'ms')
- print('brainpylib_5: ', brainpy_time5, 'ms')
- print('brainpylib_7: ', brainpy_time7, 'ms')
- print('brainpylib_9: ', brainpy_time9, 'ms')
-
-
- return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
- taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\
- brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \
- brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10
-
-PATH = os.path.dirname(os.path.abspath(__file__))
-
-# init dataframe
-df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose',
- 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
- 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)',
- 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
- 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
-
-
-### RECTANGULAR MATRIX
-if (bm.get_platform() == 'cpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _values_type in values_type:
- for _events_type in events_type:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
- # append to dataframe
- df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/csrmv_grad_cpu.csv', index=False)
-
-if (bm.get_platform() == 'gpu'):
- for shape1 in shape:
- for shape2 in shape:
- for _values_type in values_type:
- for _events_type in events_type:
- for _transpose in transpose:
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose)
- # append to dataframe
- df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose,
- taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
- taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,
- brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
- brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
- df.to_csv(f'{PATH}/csrmv_grad_gpu.csv', index=False)
diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py
deleted file mode 100644
index acedcff12..000000000
--- a/brainpy/_src/math/sparse/tests/test_csrmv.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from functools import partial
-
-import jax
-import pytest
-from absl.testing import parameterized
-
-import brainpy as bp
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
-
-import platform
-force_test = False # turn on to force test on windows locally
-if platform.system() == 'Windows' and not force_test:
- pytest.skip('skip windows', allow_module_level=True)
-
-
-seed = 1234
-
-
-def sum_op(op):
- def func(*args, **kwargs):
- r = op(*args, **kwargs)
- return r.sum()
-
- return func
-
-
-def compare_with_nan_tolerance(a, b, tol=1e-8):
- """
- Compare two arrays with tolerance for NaN values.
-
- Parameters:
- a (np.array): First array to compare.
- b (np.array): Second array to compare.
- tol (float): Tolerance for comparing non-NaN elements.
-
- Returns:
- bool: True if arrays are similar within the tolerance, False otherwise.
- """
- if a.shape != b.shape:
- return False
-
- # Create masks for NaNs in both arrays
- nan_mask_a = bm.isnan(a)
- nan_mask_b = bm.isnan(b)
-
- # Check if NaN positions are the same in both arrays
- if not bm.array_equal(nan_mask_a, nan_mask_b):
- return False
-
- # Compare non-NaN elements
- a_non_nan = a[~nan_mask_a]
- b_non_nan = b[~nan_mask_b]
-
- return bm.allclose(a_non_nan, b_non_nan, atol=tol)
-
-
-class Test_csrmv_taichi(parameterized.TestCase):
- def __init__(self, *args, platform='cpu', **kwargs):
- super(Test_csrmv_taichi, self).__init__(*args, **kwargs)
-
- print()
- bm.set_platform(platform)
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(200, 200), (10, 1000)],
- homo_data=[1.]
- )
- def test_homo(self, transpose, shape, homo_data):
- print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
- conn = bp.conn.FixedProb(0.3)
-
- # matrix
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- # vector
- rng = bm.random.RandomState(seed=seed)
- vector = rng.random(shape[0] if transpose else shape[1])
- vector = bm.as_jax(vector)
-
- heter_data = bm.ones(indices.shape).value * homo_data
-
- dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- r1 = (vector @ dense) if transpose else (dense @ vector)
- r2 = bm.sparse.csrmv(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r1, r2))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(200, 200), (100, 1000)],
- v=[1.]
- )
- def test_homo_vmap(self, transpose, shape, v):
- print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}')
- rng = bm.random.RandomState(seed=seed)
- conn = bp.conn.FixedProb(0.3)
-
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- vector = rng.random(shape[0] if transpose else shape[1])
- vector = bm.as_jax(vector)
-
- heter_data = bm.ones((10, indices.shape[0])).value * v
- homo_data = bm.ones(10).value * v
- dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data)
-
- f1 = lambda a: (a.T @ vector) if transpose else (a @ vector)
- f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector,
- shape=shape, transpose=transpose)
- r1 = jax.vmap(f1)(dense_data)
- r2 = jax.vmap(f2)(homo_data)
- self.assertTrue(bm.allclose(r1, r2))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(200, 200), (10, 1000)],
- homo_data=[1.]
- )
- def test_homo_grad(self, transpose, shape, homo_data):
- print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
- rng = bm.random.RandomState(seed=seed)
- conn = bp.conn.FixedProb(0.3)
-
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value,
- indices,
- indptr,
- shape=shape)
- vector = rng.random(shape[0] if transpose else shape[1])
- vector = bm.as_jax(vector)
-
- # print('grad data start')
- # grad 'data'
- dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum()
- if transpose else
- ((dense * a) @ vector).sum()),
- argnums=0)
- r1 = dense_f1(homo_data)
- r2 = jax.grad(sum_op(bm.sparse.csrmv))(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose)
-
- self.assertTrue(bm.allclose(r1, r2))
-
- # print('grad vector start')
- # grad 'vector'
- dense_data = dense * homo_data
- dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()))
- r3 = dense_f2(vector)
- r4 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
- bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-
- self.assertTrue(bm.allclose(r3, r4))
-
- dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum()
- if transpose else
- ((dense * a) @ v).sum()),
- argnums=(0, 1))
- r5 = dense_f3(homo_data, vector)
- r6 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))(
- bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
- self.assertTrue(bm.allclose(r5[0], r6[0]))
- self.assertTrue(bm.allclose(r5[1], r6[1]))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(200, 200), (2, 2000)],
- )
- def test_heter(self, transpose, shape):
- print(f'test_homo: transpose = {transpose} shape = {shape}')
- rng = bm.random.RandomState(seed=seed)
- conn = bp.conn.FixedProb(0.3)
-
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
-
- heter_data = bm.as_jax(rng.random(indices.shape))
- heter_data = bm.as_jax(heter_data)
-
- vector = rng.random(shape[0] if transpose else shape[1])
- vector = bm.as_jax(vector)
-
- dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- r1 = (vector @ dense) if transpose else (dense @ vector)
- r2 = bm.sparse.csrmv(heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
-
- self.assertTrue(compare_with_nan_tolerance(r1, r2))
-
- bm.clear_buffer_memory()
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(200, 200), (2, 2000)]
- )
- def test_heter_vmap(self, transpose, shape):
- rng = bm.random.RandomState(seed=seed)
- conn = bp.conn.FixedProb(0.3)
-
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- vector = rng.random(shape[0] if transpose else shape[1])
- vector = bm.as_jax(vector)
-
- heter_data = rng.random((10, indices.shape[0]))
- heter_data = bm.as_jax(heter_data)
- dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr,
- shape=shape))(heter_data)
-
- f1 = lambda a: (a.T @ vector) if transpose else (a @ vector)
- f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector,
- shape=shape, transpose=transpose)
- r1 = jax.vmap(f1)(dense_data)
- r2 = jax.vmap(f2)(heter_data)
- self.assertTrue(compare_with_nan_tolerance(r1, r2))
-
- @parameterized.product(
- transpose=[True, False],
- shape=[(200, 200), (2, 2000)]
- )
- def test_heter_grad(self, transpose, shape):
- rng = bm.random.RandomState(seed=seed)
- conn = bp.conn.FixedProb(0.3)
-
- indices, indptr = conn(*shape).require('pre2post')
- indices = bm.as_jax(indices)
- indptr = bm.as_jax(indptr)
- heter_data = rng.random(indices.shape)
- heter_data = bm.as_jax(heter_data)
- dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
- vector = rng.random(shape[0] if transpose else shape[1])
- vector = bm.as_jax(vector)
-
- # grad 'data'
- dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), argnums=0)
- csr_f1 = jax.grad(lambda a: bm.sparse.csrmv(a, indices, indptr, vector,
- shape=shape,
- transpose=transpose).sum(),
- argnums=0)
- r1 = csr_f1(heter_data)
- r2 = dense_f1(dense_data)
- rows, cols = bm.sparse.csr_to_coo(indices, indptr)
- r2 = r2[rows, cols]
- print(r1.shape, r2.shape)
- self.assertTrue(bm.allclose(r1, r2))
-
- # grad 'vector'
- dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0)
- csr_f2 = jax.grad(lambda v: bm.sparse.csrmv(heter_data, indices, indptr, v,
- shape=shape,
- transpose=transpose).sum(),
- argnums=0)
- r3 = dense_f2(vector)
- r4 = csr_f2(vector)
- self.assertTrue(bm.allclose(r3, r4))
-
- bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/sparse/utils.py b/brainpy/_src/math/sparse/utils.py
index f5b74e5eb..38cfdb7b9 100644
--- a/brainpy/_src/math/sparse/utils.py
+++ b/brainpy/_src/math/sparse/utils.py
@@ -1,22 +1,46 @@
# -*- coding: utf-8 -*-
import warnings
+from functools import partial
from typing import Tuple
import numpy as np
+from brainpy._src.math.interoperability import as_jax
from jax import core, numpy as jnp
+from jax import lax
+from jax.interpreters import batching
from jax.interpreters import mlir, ad
+from jax.tree_util import tree_flatten, tree_unflatten
from jaxlib import gpu_sparse
-from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import register_general_batching
-
__all__ = [
'coo_to_csr',
'csr_to_coo',
'csr_to_dense'
]
+def _general_batching_rule(prim, args, axes, **kwargs):
+ batch_axes, batch_args, non_batch_args = [], {}, {}
+ for ax_i, ax in enumerate(axes):
+ if ax is None:
+ non_batch_args[f'ax{ax_i}'] = args[ax_i]
+ else:
+ batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0)
+ batch_axes.append(ax_i)
+
+ def f(_, x):
+ pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
+ for i in range(len(axes))])
+ return 0, prim.bind(*pars, **kwargs)
+
+ _, outs = lax.scan(f, 0, batch_args)
+ out_vals, out_tree = tree_flatten(outs)
+ out_dim = tree_unflatten(out_tree, (0,) * len(out_vals))
+ return outs, out_dim
+
+def _register_general_batching(prim):
+ batching.primitive_batchers[prim] = partial(_general_batching_rule, prim)
+
def coo_to_csr(
pre_ids: jnp.ndarray,
@@ -153,6 +177,6 @@ def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape):
ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None)
ad.primitive_transposes[csr_to_dense_p] = _csr_to_dense_transpose
mlir.register_lowering(csr_to_dense_p, _csr_to_dense_lowering)
-register_general_batching(csr_to_dense_p)
+_register_general_batching(csr_to_dense_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(csr_to_dense_p, _csr_to_dense_gpu_lowering, platform='cuda')
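The `_register_general_batching` helper added above supplies a generic `vmap` rule that scans over the batched operands with `lax.scan`. A minimal usage sketch (assuming BrainPy is installed; the toy matrix is illustrative): because the rule is registered for `csr_to_dense_p`, vmapping `csr_to_dense` over a batch of data vectors works without a dedicated batching rule.

import jax
import jax.numpy as jnp
import brainpy.math as bm

# A 2x2 CSR matrix [[1., 0.], [0., 2.]]
data = jnp.array([1., 2.])
indices = jnp.array([0, 1])
indptr = jnp.array([0, 1, 2])

# The general batching rule lowers this vmap to a lax.scan over axis 0.
batched_data = jnp.stack([data, 2 * data])
dense = jax.vmap(
    lambda d: bm.sparse.csr_to_dense(d, indices, indptr, shape=(2, 2))
)(batched_data)
print(dense.shape)  # (2, 2, 2)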
diff --git a/brainpy/_src/math/surrogate/__init__.py b/brainpy/_src/math/surrogate/__init__.py
index 199eac648..f88816d70 100644
--- a/brainpy/_src/math/surrogate/__init__.py
+++ b/brainpy/_src/math/surrogate/__init__.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
-from .base import *
from ._one_input_new import *
from ._two_inputs import *
diff --git a/brainpy/_src/math/surrogate/_one_input.py b/brainpy/_src/math/surrogate/_one_input.py
index 59b4ab09b..9dcf8c756 100644
--- a/brainpy/_src/math/surrogate/_one_input.py
+++ b/brainpy/_src/math/surrogate/_one_input.py
@@ -8,7 +8,6 @@
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
-from .base import Surrogate
__all__ = [
'sigmoid',
@@ -32,6 +31,16 @@
]
+class Surrogate(object):
+ """The base surrograte gradient function."""
+
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
+
+
class _OneInpSurrogate(Surrogate):
def __init__(self, forward_use_surrogate=False):
self.forward_use_surrogate = forward_use_surrogate
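With the `.base` import removed, `Surrogate` is now defined directly in this module (the old `base.py` is deleted further below, and the next hunk also exports a `Surrogate` from `_one_input_new.py`). A hypothetical subclass, just to illustrate the expected interface:

class MySurrogate(Surrogate):
    """Toy surrogate: identity inside a window around zero (illustrative only)."""

    def __init__(self, width=1.0):
        self.width = width

    def __call__(self, x):
        return (abs(x) < self.width) * x

print(MySurrogate())  # prints "MySurrogate()" via the base __repr__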
diff --git a/brainpy/_src/math/surrogate/_one_input_new.py b/brainpy/_src/math/surrogate/_one_input_new.py
index bfffd88f5..ed9957261 100644
--- a/brainpy/_src/math/surrogate/_one_input_new.py
+++ b/brainpy/_src/math/surrogate/_one_input_new.py
@@ -12,6 +12,7 @@
from brainpy._src.math.ndarray import Array
__all__ = [
+ 'Surrogate',
'Sigmoid',
'sigmoid',
'PiecewiseQuadratic',
@@ -61,7 +62,7 @@ def _heaviside_imp(x, dx):
def _heaviside_batching(args, axes):
- return heaviside_p.bind(*args), axes
+ return heaviside_p.bind(*args), [axes[0]]
def _heaviside_jvp(primals, tangents):
diff --git a/brainpy/_src/math/surrogate/base.py b/brainpy/_src/math/surrogate/base.py
deleted file mode 100644
index dceb58b5c..000000000
--- a/brainpy/_src/math/surrogate/base.py
+++ /dev/null
@@ -1,19 +0,0 @@
-
-
-__all__ = [
- 'Surrogate'
-]
-
-
-class Surrogate(object):
- """The base surrograte gradient function."""
- def __call__(self, *args, **kwargs):
- raise NotImplementedError
-
- def __repr__(self):
- return f'{self.__class__.__name__}()'
-
-
-
-
-
diff --git a/brainpy/_src/math/tests/test_tifunc.py b/brainpy/_src/math/tests/test_tifunc.py
index db6e7debc..5bf0a0ad5 100644
--- a/brainpy/_src/math/tests/test_tifunc.py
+++ b/brainpy/_src/math/tests/test_tifunc.py
@@ -9,11 +9,6 @@
import matplotlib.pyplot as plt
import os
-from brainpy._src.dependency_check import import_taichi
-
-ti = import_taichi(error_if_not_found=False)
-if ti is None:
- pytest.skip('no taichi', allow_module_level=True)
bm.set_platform('cpu')
diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py
deleted file mode 100644
index 9cfd39e1a..000000000
--- a/brainpy/_src/math/tifunc.py
+++ /dev/null
@@ -1,345 +0,0 @@
-from brainpy._src.dependency_check import import_taichi, raise_taichi_not_found
-from . import defaults
-
-ti = import_taichi(error_if_not_found=False)
-
-__all__ = [
- # taichi function for other utilities
- 'warp_reduce_sum',
-
- # taichi functions for random number generator with LFSR88 algorithm
- 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn',
- 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand',
-
- # taichi functions for random number generator with LFSR113 algorithm
- 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn',
- 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand',
-]
-
-if ti is not None:
-
- #############################################
- # Random Number Generator: LFSR88 algorithm #
- #############################################
-
- @ti.func
- def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32):
- """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer).
-
- This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``.
-
- Source:
- https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c
-
- /**** VERY IMPORTANT **** :
- The initial seeds s1, s2, s3 MUST be larger than
- 1, 7, and 15 respectively.
- */
-
- Args:
- seed: int. The seed value for the random number generator.
-
- Returns:
- ti.math.uvec4: The random key for the LFSR88 random number generator.
- """
- return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0))
-
-
- @ti.func
- def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32):
- """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer).
-
- Args:
- key: The state value for the random number generator.
-
- Returns:
- ti.math.uvec4: The next random key.
- """
- b = ti.u32(((key[0] << 13) ^ key[0]) >> 19)
- s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b
- b = ((key[1] << 2) ^ key[1]) >> 25
- s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b
- b = ((key[2] << 3) ^ key[2]) >> 11
- s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b
- return ti.math.uvec4(s1, s2, s3, b)
-
-
- @ti.func
- def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10):
- """
- Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm.
-
- Args:
- key: The state value for the random number generator.
- mu: The mean of the normal distribution.
- sigma: The standard deviation of the normal distribution.
- epsilon: The epsilon value to avoid log(0).
- """
-
- key, r = lfsr88_randn(key, epsilon)
- return key, mu + sigma * r
-
-
- @ti.func
- def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10):
- """
- Generate a random number with the standard normal distribution using the LFSR88 algorithm.
-
- Args:
- key: The state value for the random number generator.
- epsilon: The epsilon value to avoid log(0).
-
- References:
- Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
- Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method
-
- """
-
- key, u1 = lfsr88_rand(key)
- key, u2 = lfsr88_rand(key)
-
- # Ensure state1 is not zero to avoid log(0)
- u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float)
-
- # Normalize the uniform samples
- mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float)
-
- # Box-Muller transform
- # z1 = mag * ti.cos(2 * ti.math.pi * u2)
- z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float)
-
- return key, z2
-
-
- @ti.func
- def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high):
- """
- Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm.
-
- Parameters:
- key: The state value used for random number generation.
- low: The lower bound of the range.
- high: The upper bound of the range.
- """
- key = lfsr88_next_key(key)
- return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int)
-
-
- @ti.func
- def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32):
- key = lfsr88_next_key(key)
- return key, dtype(key[0] ^ key[1] ^ key[2])
-
-
- @ti.func
- def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high):
- """
- Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm.
-
- Args:
- key: The state value used for random number generation.
- low: The lower bound of the range.
- high: The upper bound of the range.
- """
- key = lfsr88_next_key(key)
- r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float)
- return key, ti.cast(r * (high - low) + low, defaults.ti_float)
-
-
- @ti.func
- def lfsr88_rand(key: ti.types.vector(4, ti.u32)):
- """
- Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm.
-
- Args:
- key: The state value used for random number generation.
- """
- key = lfsr88_next_key(key)
- return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float)
-
-
- ##############################################
- # Random Number Generator: LFSR113 algorithm #
- ##############################################
-
- @ti.func
- def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32):
- """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer).
-
- This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``.
-
- Source:
- https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c
-
- /**** VERY IMPORTANT **** :
- The initial seeds s1, s2, s3, s4 MUST be larger than
- 1, 7, 15, and 127 respectively.
- */
-
- Args:
- seed: int. The seed value for the random number generator.
-
- Returns:
- ti.math.uvec4: The random key for the LFSR113 random number generator.
- """
- return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127))
-
-
- @ti.func
- def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32):
- """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer).
-
- Args:
- key: The state value for the random number generator.
-
- Returns:
- ti.math.uvec4: The next random key.
- """
- z1 = key[0]
- z2 = key[1]
- z3 = key[2]
- z4 = key[3]
- b = ((z1 << 6) ^ z1) >> 13
- z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b)
- b = ((z2 << 2) ^ z2) >> 27
- z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b)
- b = ((z3 << 13) ^ z3) >> 21
- z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b)
- b = ((z4 << 3) ^ z4) >> 12
- z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b)
- return ti.math.uvec4(z1, z2, z3, z4)
-
-
- @ti.func
- def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10):
- """
- Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm.
-
- Args:
- key: The state value for the random number generator.
- mu: The mean of the normal distribution.
- sigma: The standard deviation of the normal distribution.
- epsilon: The epsilon value to avoid log(0).
- """
-
- key, r = lfsr113_randn(key, epsilon)
- return key, ti.cast(mu + sigma * r, defaults.ti_float)
-
-
- @ti.func
- def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10):
- """
- Generate a random number with standard normal distribution using the LFSR113 algorithm.
-
- Args:
- key: The state value for the random number generator.
- epsilon: The epsilon value to avoid log(0).
-
- References:
- Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
- Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method
-
- """
-
- key, u1 = lfsr113_rand(key)
- key, u2 = lfsr113_rand(key)
-
- # Ensure state1 is not zero to avoid log(0)
- u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float)
-
- # Normalize the uniform samples
- mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float)
-
- # Box-Muller transform
- # z1 = mag * ti.cos(2 * ti.math.pi * u2)
- z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float)
-
- return key, z2
-
-
- @ti.func
- def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high):
- """
- Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm.
-
- Parameters:
- key: The state value used for random number generation.
- low: The lower bound of the range.
- high: The upper bound of the range.
- """
- key = lfsr113_next_key(key)
- return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int)
-
-
- @ti.func
- def lfsr113_randint(key: ti.types.vector(4, ti.u32)):
- key = lfsr113_next_key(key)
- return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int)
-
-
- @ti.func
- def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high):
- """
- Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm.
-
- Args:
- key: The state value used for random number generation.
- low: The lower bound of the range.
- high: The upper bound of the range.
- """
- key = lfsr88_next_key(key)
- r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float)
- return key, ti.cast(r * (high - low) + low, defaults.ti_float)
-
-
- @ti.func
- def lfsr113_rand(key: ti.types.vector(4, ti.u32)):
- """
- Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm.
-
- Args:
- key: The state value used for random number generation.
- """
- key = lfsr113_next_key(key)
- return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float)
-
-
- ###########################
- # Reductions: warp reduce #
- ###########################
-
- @ti.func
- def warp_reduce_sum_all(val):
- """
- Warp reduce sum.
-
- Args:
- val (float): The value to be reduced.
-
- Returns:
- float: The reduced value.
- """
- for i in ti.static(range(1, 32)):
- val += ti.static(ti.simt.warp.shfl_xor(val, i))
- return val
-
-
- @ti.func
- def warp_reduce_sum(val):
- """
- Warp reduce sum.
-
- Args:
- val (float): The value to be reduced.
-
- Returns:
- float: The reduced value.
- """
- for offset in ti.static((16, 8, 4, 2, 1)):
- val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset)
- return val
-
-
-else:
- for func in __all__:
- globals()[func] = raise_taichi_not_found
\ No newline at end of file
diff --git a/brainpy/_src/optimizers/tests/test_ModifyLr.py b/brainpy/_src/optimizers/tests/test_ModifyLr.py
index 01e51016e..67e1f6378 100644
--- a/brainpy/_src/optimizers/tests/test_ModifyLr.py
+++ b/brainpy/_src/optimizers/tests/test_ModifyLr.py
@@ -28,7 +28,7 @@ def train_data():
class RNN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden):
super(RNN, self).__init__()
- self.rnn = bp.dnn.RNNCell(num_in, num_hidden, train_state=True)
+ self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
self.out = bp.dnn.Dense(num_hidden, 1)
def update(self, x):
diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py
index 980ef9986..73cee5081 100644
--- a/brainpy/_src/runners.py
+++ b/brainpy/_src/runners.py
@@ -11,7 +11,6 @@
import jax.numpy as jnp
import numpy as np
import tqdm.auto
-from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_map, tree_flatten
from brainpy import math as bm, tools
@@ -632,12 +631,17 @@ def _step_func_predict(self, i, *x, shared_args=None):
# finally
if self.progress_bar:
- id_tap(lambda *arg: self._pbar.update(), ())
+ jax.debug.callback(lambda *args: self._pbar.update(), ())
# share.clear_shargs()
clear_input(self.target)
if self._memory_efficient:
- id_tap(self._step_mon_on_cpu, mon)
+ mon_shape_dtype = jax.ShapeDtypeStruct(mon.shape, mon.dtype)
+ result = jax.pure_callback(
+ self._step_mon_on_cpu,
+ mon_shape_dtype,
+ mon,
+ )
return out, None
else:
return out, mon
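
The hunks in `runners.py` above, and the matching ones in `offline.py`, `online.py`, and `check.py` below, replace the removed `jax.experimental.host_callback.id_tap` with `jax.debug.callback` (fire-and-forget side effects such as progress-bar updates) and `jax.pure_callback` (host computations whose result, with a declared shape/dtype, re-enters the traced program). A minimal, self-contained sketch of that migration pattern, assuming only public JAX APIs; `_host_progress` and `_host_copy` are illustrative helpers, not part of BrainPy:

```python
import jax
import jax.numpy as jnp
import numpy as np

def _host_progress(*_):
    # Side-effect-only host call (e.g. a tqdm update); returns nothing.
    print("one step finished")

def _host_copy(x):
    # Host computation whose result is fed back into the traced program.
    return np.asarray(x) + 1.0

@jax.jit
def step(x):
    y = x * 2.0
    # Old: id_tap(lambda *a: pbar.update(), ())  ->  new side-effect callback.
    jax.debug.callback(_host_progress)
    # Old: id_tap(fn, y) with the result consumed  ->  pure_callback requires
    # the output shape/dtype to be declared up front.
    spec = jax.ShapeDtypeStruct(y.shape, y.dtype)
    return jax.pure_callback(_host_copy, spec, y)

print(step(jnp.ones(3, dtype=jnp.float32)))
```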
diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py
index 6f2411ee8..037f283ac 100644
--- a/brainpy/_src/tests/test_dyn_runner.py
+++ b/brainpy/_src/tests/test_dyn_runner.py
@@ -5,11 +5,6 @@
import brainpy as bp
import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
- pytest.skip('no taichi', allow_module_level=True)
-
class TestDSRunner(unittest.TestCase):
def test1(self):
diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py
index 2bfa419d6..36ed3c2b9 100644
--- a/brainpy/_src/train/offline.py
+++ b/brainpy/_src/train/offline.py
@@ -2,9 +2,9 @@
from typing import Dict, Sequence, Union, Callable, Any
+import jax
import numpy as np
import tqdm.auto
-from jax.experimental.host_callback import id_tap
import brainpy.math as bm
from brainpy import tools
@@ -219,7 +219,7 @@ def _fun_train(self,
targets = target_data[node.name]
node.offline_fit(targets, fit_record)
if self.progress_bar:
- id_tap(lambda *args: self._pbar.update(), ())
+ jax.debug.callback(lambda *args: self._pbar.update(), ())
def _step_func_monitor(self):
res = dict()
diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py
index d80764f26..d8e185c30 100644
--- a/brainpy/_src/train/online.py
+++ b/brainpy/_src/train/online.py
@@ -2,9 +2,9 @@
import functools
from typing import Dict, Sequence, Union, Callable
+import jax
import numpy as np
import tqdm.auto
-from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_map
from brainpy import math as bm, tools
@@ -252,7 +252,7 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):
# finally
if self.progress_bar:
- id_tap(lambda *arg: self._pbar.update(), ())
+ jax.debug.callback(lambda *args: self._pbar.update(), ())
return out, monitors
def _check_interface(self):
diff --git a/brainpy/check.py b/brainpy/check.py
index fafc0551d..1f809d840 100644
--- a/brainpy/check.py
+++ b/brainpy/check.py
@@ -7,7 +7,6 @@
import numpy as np
import numpy as onp
from jax import numpy as jnp
-from jax.experimental.host_callback import id_tap
from jax.lax import cond
conn = None
@@ -570,7 +569,11 @@ def is_all_objs(targets: Any, out_as: str = 'tuple'):
def _err_jit_true_branch(err_fun, x):
- id_tap(err_fun, x)
+ if isinstance(x, (tuple, list)):
+ x_shape_dtype = tuple(jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in x)
+ else:
+ x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
+ jax.pure_callback(err_fun, x_shape_dtype, x)
return
@@ -629,6 +632,6 @@ def true_err_fun(arg, transforms):
raise err
cond(remove_vmap(as_jax(pred)),
- lambda: id_tap(true_err_fun, None),
+ lambda: jax.pure_callback(true_err_fun, None),
lambda: None)
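
The change above routes the error branch of the runtime checks through a host callback under `jax.lax.cond` instead of `id_tap`. A rough sketch of that cond-gated host-callback shape under `jit`, using `jax.debug.callback` for the side effect; `check_all_positive` and `_report` are hypothetical names for illustration, not BrainPy's API:

```python
import jax
import jax.numpy as jnp

def _report(*_):
    # Host-side reporter; the real code raises the prepared error here.
    print("check failed: non-positive value encountered")

def check_all_positive(x):
    # cond-gated host callback, as in the hunk above: the first branch
    # (and hence the callback) only runs when the predicate is True.
    jax.lax.cond(
        jnp.any(x <= 0),
        lambda: jax.debug.callback(_report),
        lambda: None,
    )
    return x

print(jax.jit(check_all_positive)(jnp.array([1.0, 2.0, 3.0])))
```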
diff --git a/brainpy/integrators/__init__.py b/brainpy/integrators/__init__.py
index 176a71aec..7696bd33a 100644
--- a/brainpy/integrators/__init__.py
+++ b/brainpy/integrators/__init__.py
@@ -3,4 +3,5 @@
from . import ode
from . import sde
from . import fde
+from brainpy._src.integrators.base import compile_integrators
from brainpy._src.integrators.constants import *
diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py
index 08a070f02..624ade1b7 100644
--- a/brainpy/math/__init__.py
+++ b/brainpy/math/__init__.py
@@ -16,7 +16,6 @@
# operators
from .pre_syn_post import *
-from .op_register import *
from . import surrogate, event, sparse, jitconn
# Variable and Objects for object-oriented JAX transformations
@@ -33,8 +32,6 @@
from . import linalg
from . import random
-# taichi operations
-from . import tifunc
# others
from . import sharding
@@ -46,10 +43,7 @@
from brainpy._src.math import defaults
from brainpy._src.deprecations import deprecation_getattr
-from brainpy._src.dependency_check import import_taichi, import_numba
-import_taichi(error_if_not_found=False)
-import_numba(error_if_not_found=False)
__deprecations = {
"sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.",
diff --git a/brainpy/math/compat_numpy.py b/brainpy/math/compat_numpy.py
index ad6c8184f..2d068acbd 100644
--- a/brainpy/math/compat_numpy.py
+++ b/brainpy/math/compat_numpy.py
@@ -327,20 +327,14 @@
sort_complex as sort_complex,
unpackbits as unpackbits,
delete as delete,
- add_docstring as add_docstring,
- add_newdoc as add_newdoc,
- add_newdoc_ufunc as add_newdoc_ufunc,
- array2string as array2string,
asanyarray as asanyarray,
ascontiguousarray as ascontiguousarray,
asfarray as asfarray,
asscalar as asscalar,
common_type as common_type,
- disp as disp,
genfromtxt as genfromtxt,
loadtxt as loadtxt,
info as info,
- issubclass_ as issubclass_,
place as place,
polydiv as polydiv,
put as put,
diff --git a/brainpy/math/event.py b/brainpy/math/event.py
index 02e98b8f3..3b4b5ed1e 100644
--- a/brainpy/math/event.py
+++ b/brainpy/math/event.py
@@ -1,3 +1,4 @@
from brainpy._src.math.event import (
csrmv as csrmv,
+ csrmm as csrmm,
)
diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py
index a87d27d58..3c99b7de7 100644
--- a/brainpy/math/jitconn.py
+++ b/brainpy/math/jitconn.py
@@ -6,5 +6,9 @@
mv_prob_homo as mv_prob_homo,
mv_prob_uniform as mv_prob_uniform,
mv_prob_normal as mv_prob_normal,
+
+ get_homo_weight_matrix as get_homo_weight_matrix,
+ get_uniform_weight_matrix as get_uniform_weight_matrix,
+ get_normal_weight_matrix as get_normal_weight_matrix,
)
diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py
deleted file mode 100644
index f383c1a20..000000000
--- a/brainpy/math/op_register.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# -*- coding: utf-8 -*-
-from brainpy._src.math.op_register import (
- CustomOpByNumba,
- compile_cpu_signature_with_numba,
- clear_taichi_aot_caches,
- count_taichi_aot_kernels,
-)
-
-from brainpy._src.math.op_register.base import XLACustomOp
-from brainpy._src.math.op_register.ad_support import defjvp
-
-
diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py
index aa86679ec..8a209901f 100644
--- a/brainpy/math/sparse.py
+++ b/brainpy/math/sparse.py
@@ -4,6 +4,9 @@
)
from brainpy._src.math.sparse import (
csrmv,
+ csrmm,
+
+ seg_matmul,
csr_to_dense as csr_to_dense,
csr_to_coo as csr_to_coo,
diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py
deleted file mode 100644
index bea49c220..000000000
--- a/brainpy/math/tifunc.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from brainpy._src.math.tifunc import (
-
- # warp reduction primitives
- warp_reduce_sum,
-
- # random number generator
- lfsr88_key,
- lfsr88_next_key,
- lfsr88_normal,
- lfsr88_randn,
- lfsr88_random_integers,
- lfsr88_randint,
- lfsr88_uniform,
- lfsr88_rand,
- lfsr113_key,
- lfsr113_next_key,
- lfsr113_normal,
- lfsr113_randn,
- lfsr113_random_integers,
- lfsr113_randint,
- lfsr113_uniform,
- lfsr113_rand
-)
diff --git a/docs/apis/connect.rst b/docs/apis/connect.rst
index 9c42fbabb..759f67d1d 100644
--- a/docs/apis/connect.rst
+++ b/docs/apis/connect.rst
@@ -27,22 +27,9 @@ Base Connection Classes and Tools
coo2csr
coo2csc
coo2mat
- coo2mat_num
- mat2mat_num
- visualizeMat
- MAT_DTYPE
- IDX_DTYPE
Connector
TwoEndConnector
OneEndConnector
- CONN_MAT
- PRE_IDS
- POST_IDS
- PRE2POST
- POST2PRE
- PRE2SYN
- POST2SYN
- SUPPORTED_SYN_STRUCTURE
Custom Connections
diff --git a/docs/index.rst b/docs/index.rst
index d4d4f2721..29f7a901e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -110,6 +110,20 @@ Learn more
APIs may be changed over time. Please always keeps
in mind what BrainPy version you are using.
+.. note::
+   Starting from our experimental BrainPy package, a better and more mature ecosystem for brain dynamics programming is emerging.
+   Please see the `Brain Dynamics Programming Ecosystem <https://ecosystem-for-brain-dynamics.readthedocs.io/>`_ for more details.
+
+   If you are heavily using BrainPy, please consider using `brainstate <https://brainstate.readthedocs.io>`_ for a more stable, efficient, concise, and powerful experience.
+
+   `brainstate <https://github.com/chaobrain/brainstate>`_ is and will be actively maintained and developed by our team.
+   We highly recommend migrating your code to brainstate for better performance.
+
+
+
+----
+
+
.. toctree::
:hidden:
diff --git a/docs/quickstart/analysis.ipynb b/docs/quickstart/analysis.ipynb
index d8b62de11..49a7d38d4 100644
--- a/docs/quickstart/analysis.ipynb
+++ b/docs/quickstart/analysis.ipynb
@@ -78,7 +78,8 @@
],
"source": [
"bp.__version__"
- ]
+ ],
+ "id": "b9f06f10049d4a8c"
},
{
"cell_type": "markdown",
@@ -108,7 +109,8 @@
"\\tau {\\dot {V}}= - (V - V_\\mathrm{rest}) + \\Delta_T \\exp(\\frac{V - V_T}{\\Delta_T}) + RI \\\\\n",
"\\mathrm{if}\\, \\, V > \\theta, \\quad V \\gets V_\\mathrm{reset}\n",
"$$"
- ]
+ ],
+ "id": "1ecf04d3eaf72477"
},
{
"cell_type": "markdown",
@@ -355,7 +357,8 @@
},
"source": [
"## Slow point analysis of a high-dimensional system"
- ]
+ ],
+ "id": "f8c9850a60c8a2b0"
},
{
"cell_type": "markdown",
@@ -366,7 +369,8 @@
"BrainPy is also capable of performing fixed/slow point analysis of high-dimensional systems. Moreover, it can perform automatic linearization analysis around the fixed point.\n",
"\n",
"In the following, we use a gap junction coupled FitzHugh–Nagumo (FHN) network as an example to demonstrate how to find fixed/slow points of a high-dimensional system."
- ]
+ ],
+ "id": "a42b23306515127c"
},
{
"cell_type": "markdown",
@@ -375,7 +379,8 @@
},
"source": [
"We first define the gap junction coupled FHN network as the normal ``DynamicalSystem`` class."
- ]
+ ],
+ "id": "24471089c2a5247d"
},
{
"cell_type": "code",
@@ -424,7 +429,8 @@
" self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)\n",
" self.w.value = self.int_w(self.w, t, self.V, dt)\n",
" self.Iext[:] = 0."
- ]
+ ],
+ "id": "f9fa252ff8617c22"
},
{
"cell_type": "markdown",
@@ -433,7 +439,8 @@
},
"source": [
"Through simulation, we can easily find that this system has a limit cycle attractor, implying that an unstable fixed point exists."
- ]
+ ],
+ "id": "158d89a270070b5d"
},
{
"cell_type": "code",
@@ -485,7 +492,8 @@
"bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V',\n",
" plot_ids=list(range(model.num)),\n",
" show=True)"
- ]
+ ],
+ "id": "58e6a900ff39ff43"
},
{
"cell_type": "markdown",
@@ -494,7 +502,8 @@
},
"source": [
"Let's try to optimize the fixed points for this system. Note that we only take care of the variables ``V`` and ``w``. Different from the low-dimensional analyzer, we should provide the candidate fixed points or initial fixed points when using the high-dimensional analyzer."
- ]
+ ],
+ "id": "22084ff723d5331"
},
{
"cell_type": "code",
@@ -582,7 +591,8 @@
"\n",
"# remove the duplicate fixed points\n",
"finder.keep_unique()"
- ]
+ ],
+ "id": "90ce07cf2a0a8e3c"
},
{
"cell_type": "code",
@@ -617,7 +627,8 @@
"source": [
"print('fixed points:', )\n",
"finder.fixed_points"
- ]
+ ],
+ "id": "45b81a4cdac07efa"
},
{
"cell_type": "code",
@@ -651,7 +662,8 @@
"source": [
"print('fixed point losses:', )\n",
"finder.losses"
- ]
+ ],
+ "id": "adc5767da75ea8b2"
},
{
"cell_type": "markdown",
@@ -660,7 +672,8 @@
},
"source": [
"Let's perform the linearization analysis of the found fixed points, and visualize its decomposition results."
- ]
+ ],
+ "id": "511b2ac23cb0c586"
},
{
"cell_type": "code",
@@ -694,7 +707,8 @@
],
"source": [
"_ = finder.compute_jacobians(finder.fixed_points, plot=True)"
- ]
+ ],
+ "id": "49e773e8700380ce"
},
{
"cell_type": "markdown",
@@ -703,7 +717,8 @@
},
"source": [
"This is an unstable fixed point, because one of its eigenvalues has the real part bigger than 1."
- ]
+ ],
+ "id": "a1a7df686a3bb68f"
},
{
"cell_type": "markdown",
@@ -712,7 +727,8 @@
},
"source": [
"## Further reading"
- ]
+ ],
+ "id": "b9e1a94399ffa6a1"
},
{
"cell_type": "markdown",
@@ -723,7 +739,8 @@
"- For more details about how to perform bifurcation analysis and phase plane analysis, please see the tutorial of [Low-dimensional Analyzers](../tutorial_analysis/lowdim_analysis.ipynb).\n",
"- A good example of phase plane analysis and bifurcation analysis is the decision-making model, please see the tutorial in [Analysis of a Decision-making Model](../tutorial_analysis/decision_making_model.ipynb)\n",
"- If you want to how to analyze the slow points (or fixed points) of your high-dimensional dynamical models, please see the tutorial of [High-dimensional Analyzers](../tutorial_analysis/highdim_analysis.ipynb)"
- ]
+ ],
+ "id": "7a4f436e37606967"
}
],
"metadata": {
diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
index e1121f5b6..7f00cd56e 100644
--- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
@@ -65,8 +65,6 @@
"source": [
"### ``brainpy.math.CustomOpByNumba``\n",
"\n",
- "``brainpy.math.CustomOpByNumba`` is also called ``brainpy.math.XLACustomOp``.\n",
- "\n",
"BrainPy provides ``brainpy.math.CustomOpByNumba`` for customizing the operator on the CPU device. Two parameters are required to provide in ``CustomOpByNumba``:\n",
"\n",
"- ``eval_shape``: evaluates the *shape* and *datatype* of the output argument based on the *shape* and *datatype* of the input argument.\n",
@@ -137,7 +135,7 @@
"collapsed": false
},
"source": [
- "### Return multiple values ``multiple_returns=True``\n",
+ "#### Return multiple values ``multiple_returns=True``\n",
"\n",
"If the result of our computation needs to return multiple arrays, then we need to use ``multiple_returns=True`` in our use of registering the operator. In this case, ``outs`` will be a list containing multiple arrays, not an array.\n",
"\n",
@@ -149,8 +147,10 @@
" return c, d\n",
"\n",
"def con_compute2(outs, ins):\n",
- " c, d = outs # 取出所有的输出\n",
- " a, b = ins # 取出所有的输入\n",
+ " c = outs[0] # take out all the outputs\n",
+ " d = outs[1]\n",
+ " a = ins[0] # take out all the inputs\n",
+ " b = ins[1]\n",
" c[:] = a + 1\n",
" d[:] = a * 2\n",
"\n",
@@ -170,7 +170,7 @@
"collapsed": false
},
"source": [
- "### Non-Tracer parameters\n",
+ "#### Non-Tracer parameters\n",
"\n",
"In the ``eval_shape`` function, all arguments are abstract information (containing only the shape and type) if they are arguments that can be traced by ``jax.jit``. However, if we infer the output data type requires additional information beyond the input parameter information, then we need to define non-Tracer parameters.\n",
"\n",
@@ -191,7 +191,8 @@
"\n",
"def con_compute3(outs, ins):\n",
" c = outs # Take out all the outputs\n",
- " a, b = ins # Take out all inputs\n",
+ " a = ins[0] # Take out all inputs\n",
+ " b = ins[1]\n",
" c[:] = 2.\n",
"\n",
"op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
@@ -221,7 +222,7 @@
"collapsed": false
},
"source": [
- "### Example: A sparse operator\n",
+ "#### Example: A sparse operator\n",
"\n",
"To illustrate the effectiveness of this approach, we define in this an event-driven sparse computation operator."
]
@@ -297,6 +298,50 @@
"f(1.)"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### brainpy.math.XLACustomOp\n",
+ "\n",
+    "`brainpy.math.XLACustomOp` is a newer way to customize operators. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. To use it with numba (on the CPU device), you only need to define a kernel with the `@numba.jit` or `@numba.njit` decorator and then pass that kernel to `brainpy.math.XLACustomOp`.\n",
+ "\n",
+ "Detailed steps are as follows:\n",
+ "\n",
+ "#### Define the kernel\n",
+ "\n",
+ "```python\n",
+ "@numba.njit(fastmath=True)\n",
+ "def numba_event_csrmv(weight, indices, vector, outs):\n",
+ " outs.fill(0)\n",
+ " weight = weight[()] # 0d\n",
+ " for row_i in range(vector.shape[0]):\n",
+ " if vector[row_i]:\n",
+ " for j in indices[row_i]:\n",
+ " outs[j] += weight\n",
+ "```\n",
+ "\n",
+    "In the parameter declaration, the output parameters must come last so that numba can compile the kernel correctly. The operator `numba_event_csrmv` receives four parameters: `weight`, `indices`, `vector`, and `outs`. The first three are inputs and the last one is the output. The output is a 1D array, while the inputs are a 0D, a 2D, and a 1D array, respectively."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Registering and Using Custom Operators\n",
+ "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using `jax.ShapeDtypeStruct` to define the shape and data type of the output.\n",
+ "\n",
+    "Note: keep the arguments passed at call time in the same order as the parameters declared in the kernel.\n",
+ "\n",
+ "```python\n",
+ "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n",
+ "indices = bm.random.randint(0, s, (s, 80))\n",
+ "vector = bm.random.rand(s) < 0.1\n",
+ "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])\n",
+ "print(out)\n",
+ "```"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -423,7 +468,7 @@
"collapsed": false
},
"source": [
- "### 返回多个值 ``multiple_returns=True``\n",
+ "#### 返回多个值 ``multiple_returns=True``\n",
"\n",
"如果我们的计算结果需要返回多个数组,那么,我们在注册算子的使用需要使用``multiple_returns=True``。此时,``outs``将会是一个包含多个数组的列表,而不是一个数组。\n",
"\n",
@@ -434,8 +479,10 @@
" return c, d # 返回多个抽象数组信息\n",
"\n",
"def con_compute2(outs, ins):\n",
- " c, d = outs # 取出所有的输出\n",
- " a, b = ins # 取出所有的输入\n",
+ " c = outs[0] # 取出所有的输出\n",
+ " d = outs[1]\n",
+ " a = ins[0] # 取出所有的输入\n",
+ " b = ins[1]\n",
" c[:] = a + 1\n",
" d[:] = a * 2\n",
"\n",
@@ -455,7 +502,7 @@
"collapsed": false
},
"source": [
- "### 非Tracer参数\n",
+ "#### 非Tracer参数\n",
"\n",
"在``eval_shape``函数中推断数据类型时,如果所有参数都是可以被``jax.jit``追踪的参数,那么所有参数都是抽象信息(只包含形状和类型)。如果有时推断输出数据类型时还需要除输入参数信息以外的额外信息,此时我们需要定义非Tracer参数。\n",
"\n",
@@ -476,7 +523,8 @@
"\n",
"def con_compute3(outs, ins):\n",
" c = outs # 取出所有的输出\n",
- " a, b = ins # 取出所有的输入\n",
+ " a = ins[0] # 取出所有的输入\n",
+ " b = ins[1]\n",
" c[:] = 2.\n",
"\n",
"op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
@@ -506,7 +554,7 @@
"collapsed": false
},
"source": [
- "### 示例:一个稀疏算子\n",
+ "#### 示例:一个稀疏算子\n",
"\n",
"为了说明这种方法的有效性,我们在这个定义一个事件驱动的稀疏计算算子。"
]
@@ -581,6 +629,50 @@
"f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n",
"f(1.)"
]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### brainpy.math.XLACustomOp\n",
+ "\n",
+ "`brainpy.math.XLACustomOp` is a new method for customizing operators on the CPU device. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. If you want to use this new method with numba, you only need to define a kernel using `@numba.jit` or `@numba.njit` decorator, and then pass the kernel to `brainpy.math.XLACustomOp`.\n",
+ "`brainpy.math.XLACustomOp`是一种自定义算子的新方法。它类似于`brainpy.math.CustomOpByNumba`,但它更灵活并支持更高级的特性。如果您想用numba使用这种新方法,只需要使用 `@numba.jit`或`@numba.njit`装饰器定义一个kernel,然后将内核传递给`brainpy.math.XLACustomOp`。\n",
+ "\n",
+ "详细步骤如下:\n",
+ "\n",
+ "#### 定义kernel\n",
+ "在参数声明中,最后几个参数需要是输出参数,这样numba才能正确编译。这个算子`numba_event_csrmv`接受四个参数:weight、indices、vector 和 outs。前三个参数是输入参数,最后一个参数是输出参数。输出参数是一个一维数组,输入参数分别是 0D、1D 和 2D 数组。\n",
+ "\n",
+ "```python\n",
+ "@numba.njit(fastmath=True)\n",
+ "def numba_event_csrmv(weight, indices, vector, outs):\n",
+ " outs.fill(0)\n",
+ " weight = weight[()] # 0d\n",
+ " for row_i in range(vector.shape[0]):\n",
+ " if vector[row_i]:\n",
+ " for j in indices[row_i]:\n",
+ " outs[j] += weight\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 注册并使用自定义算子\n",
+ "在定义了自定义算子之后,可以将其注册到特定框架中,并在需要的地方使用它。在注册时可以指定`cpu_kernel`和`gpu_kernel`,这样算子就可以在不同的设备上运行。并在调用中指定`outs`参数,用`jax.ShapeDtypeStruct`来指定输出的形状和数据类型。\n",
+ "\n",
+ "注意: 在算子声明的参数与调用时需要保持顺序的一致。\n",
+ "\n",
+ "```python\n",
+ "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n",
+ "indices = bm.random.randint(0, s, (s, 80))\n",
+ "vector = bm.random.rand(s) < 0.1\n",
+ "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])\n",
+ "print(out)\n",
+ "```"
+ ]
}
],
"metadata": {
diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
index 4b86a4269..e927bf72c 100644
--- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
@@ -127,7 +127,7 @@
"metadata": {},
"source": [
"### Registering and Using Custom Operators\n",
- "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using jax.ShapeDtypeStruct to define the shape and data type of the output.\n",
+ "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using `jax.ShapeDtypeStruct` to define the shape and data type of the output.\n",
"\n",
"Note: Maintain the order of the operator's declared parameters consistent with the order when calling.\n",
"\n",
diff --git a/docs/tutorial_building/customize_dynamical_systems.ipynb b/docs/tutorial_building/customize_dynamical_systems.ipynb
index ec19c06a7..f06792419 100644
--- a/docs/tutorial_building/customize_dynamical_systems.ipynb
+++ b/docs/tutorial_building/customize_dynamical_systems.ipynb
@@ -626,7 +626,7 @@
}
],
"source": [
- "runner = bp.dyn.DSRunner(fhn_net,\n",
+ "runner = bp.DSRunner(fhn_net,\n",
" monitors=['f1.v', 'X.v'], \n",
" inputs=[('f1.I', 1.5), # relative access to variable \"I\" in 'fhn1'\n",
" ('X.I', 1.0),]) # absolute access to variable \"I\" in 'fhn2'\n",
diff --git a/docs/tutorial_math/control_flows.ipynb b/docs/tutorial_math/control_flows.ipynb
index c545f55de..fa328783c 100644
--- a/docs/tutorial_math/control_flows.ipynb
+++ b/docs/tutorial_math/control_flows.ipynb
@@ -330,10 +330,10 @@
"TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1]..\n",
"The error occurred while tracing the function for eval_shape. This value became a tracer due to JAX operations on these lines:\n",
"\n",
- " operation a\u001b[35m:f32[]\u001b[39m = convert_element_type[new_dtype=float32 weak_type=False] b\n",
+ " operation a\u001B[35m:f32[]\u001B[39m = convert_element_type[new_dtype=float32 weak_type=False] b\n",
" from line D:\\codes\\projects\\brainpy-chaoming0625\\brainpy\\_src\\math\\ndarray.py:267:19 (__lt__)\n",
"\n",
- " operation a\u001b[35m:bool[1]\u001b[39m = lt b c\n",
+ " operation a\u001B[35m:bool[1]\u001B[39m = lt b c\n",
" from line D:\\codes\\projects\\brainpy-chaoming0625\\brainpy\\_src\\math\\ndarray.py:267:19 (__lt__)\n",
"See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n"
]
diff --git a/docs/tutorial_toolbox/synaptic_weights.ipynb b/docs/tutorial_toolbox/synaptic_weights.ipynb
index 312fa831f..dfa6c5133 100644
--- a/docs/tutorial_toolbox/synaptic_weights.ipynb
+++ b/docs/tutorial_toolbox/synaptic_weights.ipynb
@@ -71,7 +71,8 @@
],
"source": [
"bp.__version__"
- ]
+ ],
+ "id": "23c6c2fe06bbd897"
},
{
"cell_type": "markdown",
diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py
index 0343ae89c..4b3f5f811 100644
--- a/examples/dynamics_simulation/hh_model.py
+++ b/examples/dynamics_simulation/hh_model.py
@@ -43,16 +43,16 @@ def __init__(self, size):
self.KNa.add_elem()
-# hh = HH(1)
-# I, length = bp.inputs.section_input(values=[0, 5, 0],
-# durations=[100, 500, 100],
-# return_length=True)
-# runner = bp.DSRunner(
-# hh,
-# monitors=['V', 'INa.p', 'INa.q', 'IK.p'],
-# inputs=[hh.input, I, 'iter'],
-# )
-# runner.run(length)
-#
-# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
+hh = HH(1)
+I, length = bp.inputs.section_input(values=[0, 5, 0],
+ durations=[100, 500, 100],
+ return_length=True)
+runner = bp.DSRunner(
+ hh,
+ monitors=['V', 'INa.p', 'INa.q', 'IK.p'],
+ inputs=[hh.input, I, 'iter'],
+)
+runner.run(length)
+
+bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 641f99fde..5e9c95a32 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,14 +1,13 @@
numpy
-brainpylib
jax
jaxlib
matplotlib
msgpack
tqdm
pathos
-taichi
+braintaichi
numba
-braincore
+brainstate
braintools
diff --git a/requirements-doc.txt b/requirements-doc.txt
index 8b0a5a6a4..5c6d440ee 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -5,12 +5,12 @@ matplotlib
numpy
scipy
numba
-taichi==1.7.0
+braintaichi
# document requirements
pandoc
Jinja2
-sphinx>=5
+sphinx>=5, <8.2.0
myst-nb
sphinx_thebe
sphinx-autodoc-typehints
diff --git a/setup.py b/setup.py
index 55f948e4b..84ac38c12 100644
--- a/setup.py
+++ b/setup.py
@@ -68,9 +68,9 @@
'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html',
],
extras_require={
- 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'],
- 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib', 'numba', 'taichi==1.7.0'],
- 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib', 'numba', 'taichi==1.7.0'],
+ 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'braintaichi'],
+ 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib', 'numba', 'braintaichi'],
+ 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib', 'numba', 'braintaichi'],
'tpu': ['jaxlib[tpu]', 'numba',],
'cpu_mini': ['jaxlib>=0.4.13'],
'cuda11_mini': ['jaxlib[cuda11_pip]'],