upgrade operators to match brainpylib>=0.1.0 (#284)
chaoming0625 authored Nov 4, 2022
2 parents 15828ca + c12d316 commit f3e7d72
Showing 8 changed files with 205 additions and 95 deletions.
7 changes: 4 additions & 3 deletions brainpy/math/operators/__init__.py
@@ -1,17 +1,18 @@
 # -*- coding: utf-8 -*-
 
 
-from . import multiplication
+from . import sparse_matmul, event_matmul
 from . import op_register
 from . import pre_syn_post as pre_syn_post_module
 from . import wrap_jax
 from . import spikegrad
 
-__all__ = multiplication.__all__ + op_register.__all__
+__all__ = event_matmul.__all__ + sparse_matmul.__all__ + op_register.__all__
 __all__ += pre_syn_post_module.__all__ + wrap_jax.__all__ + spikegrad.__all__
 
 
-from .multiplication import *
+from .event_matmul import *
+from .sparse_matmul import *
 from .op_register import *
 from .pre_syn_post import *
 from .wrap_jax import *
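For orientation: the former `multiplication` module is split into `sparse_matmul` and `event_matmul`, and the star-imports keep the public namespace flat. A quick sanity check of the re-export (a sketch grounded in the `__all__` wiring above):

from brainpy.math import operators

# `event_matmul.__all__` feeds the package-level `__all__`, so the new
# operator is reachable without knowing which submodule defines it.
assert 'event_csr_matvec' in operators.__all__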
57 changes: 57 additions & 0 deletions brainpy/math/operators/event_matmul.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+
+
+from typing import Tuple
+
+from brainpy.math.numpy_ops import as_jax
+from brainpy.types import Array
+from .utils import _check_brainpylib
+
+try:
+  import brainpylib
+except ModuleNotFoundError:
+  brainpylib = None
+
+__all__ = [
+  'event_csr_matvec',
+]
+
+
+def event_csr_matvec(values: Array,
+                     indices: Array,
+                     indptr: Array,
+                     events: Array,
+                     shape: Tuple[int, ...],
+                     transpose: bool = False):
+  """The pre-to-post event-driven synaptic summation with a `CSR` synapse structure.
+
+  Parameters
+  ----------
+  values: Array, float
+    An array of shape ``(nse,)`` or a float.
+  indices: Array
+    An array of shape ``(nse,)``.
+  indptr: Array
+    An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+  events: Array
+    An array of shape ``(shape[0] if transpose else shape[1],)``,
+    typically a boolean spike vector.
+  shape: tuple of int
+    A length-2 tuple representing the sparse matrix shape.
+  transpose: bool
+    A boolean specifying whether to transpose the sparse matrix
+    before computing. Default is False.
+
+  Returns
+  -------
+  out: Array
+    An array with the shape of ``shape[1]`` if `transpose=True`,
+    or ``shape[0]`` if `transpose=False`.
+  """
+  _check_brainpylib('event_csr_matvec')
+  events = as_jax(events)
+  indices = as_jax(indices)
+  indptr = as_jax(indptr)
+  values = as_jax(values)
+  return brainpylib.event_csr_matvec(values, indices, indptr, events,
+                                     shape=shape, transpose=transpose)
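A usage sketch for the new operator (toy data; needs `brainpylib>=0.1.0` installed, and assumes `event_csr_matvec` is re-exported under `brainpy.math` via the `__init__.py` change above). The dense reference spells out what the event-driven kernel computes:

import jax.numpy as jnp
import brainpy.math as bm

# A 2x3 sparse matrix in CSR form:
# row 0 holds 1.0 at column 0 and 2.0 at column 2; row 1 holds 3.0 at column 1.
values = jnp.array([1., 2., 3.])
indices = jnp.array([0, 2, 1])
indptr = jnp.array([0, 2, 3])            # length shape[0] + 1
events = jnp.array([True, False, True])  # one boolean per column (shape[1] = 3)

out = bm.event_csr_matvec(values, indices, indptr, events,
                          shape=(2, 3), transpose=False)

# Dense reference: only columns whose event fired contribute to each row.
dense = jnp.array([[1., 0., 2.],
                   [0., 3., 0.]])
ref = dense @ events.astype(values.dtype)  # -> [3., 0.], matching `out`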
36 changes: 23 additions & 13 deletions brainpy/math/operators/op_register.py
@@ -2,7 +2,7 @@
 
 from typing import Union, Sequence, Callable
 
-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray
 from jax.tree_util import tree_map
 
 from brainpy.base import Base
@@ -57,6 +57,10 @@ def __init__(
       gpu_func: Callable = None,
       apply_cpu_func_to_gpu: bool = False,
       name: str = None,
+      batching_translation: Callable = None,
+      jvp_translation: Callable = None,
+      transpose_translation: Callable = None,
+      multiple_results: bool = False,
   ):
     _check_brainpylib(register_op.__name__)
     super(XLACustomOp, self).__init__(name=name)
@@ -77,19 +81,25 @@ def __init__(
       gpu_func = None
 
     # register OP
-    self.op = brainpylib.register_op(self.name,
-                                     cpu_func=cpu_func,
-                                     gpu_func=gpu_func,
-                                     out_shapes=eval_shape,
-                                     apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
+    self.op = brainpylib.register_op_with_numba(
+      self.name,
+      cpu_func=cpu_func,
+      gpu_func_translation=gpu_func,
+      out_shapes=eval_shape,
+      apply_cpu_func_to_gpu=apply_cpu_func_to_gpu,
+      batching_translation=batching_translation,
+      jvp_translation=jvp_translation,
+      transpose_translation=transpose_translation,
+      multiple_results=multiple_results,
+    )
 
   def __call__(self, *args, **kwargs):
     args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
                     args, is_leaf=lambda a: isinstance(a, JaxArray))
     kwargs = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
                       kwargs, is_leaf=lambda a: isinstance(a, JaxArray))
     res = self.op.bind(*args, **kwargs)
-    return res[0] if len(res) == 1 else res
+    return res
 
 
 def register_op(
@@ -122,15 +132,15 @@ def register_op(
     A jitable JAX function.
   """
   _check_brainpylib(register_op.__name__)
-  f = brainpylib.register_op(name,
-                             cpu_func=cpu_func,
-                             gpu_func=gpu_func,
-                             out_shapes=eval_shape,
-                             apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
+  f = brainpylib.register_op_with_numba(name,
+                                        cpu_func=cpu_func,
+                                        gpu_func_translation=gpu_func,
+                                        out_shapes=eval_shape,
+                                        apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
 
   def fixed_op(*inputs, **info):
     inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs])
     res = f.bind(*inputs, **info)
-    return res[0] if len(res) == 1 else res
+    return res
 
   return fixed_op
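For the registration entry point itself, a minimal sketch of a custom CPU operator under the new Numba backend. The names `double_eval`, `double_cpu`, and `'hypothetical_double'` are illustrative, and the `cpu_func(outs, ins)` convention (operands arriving as a tuple, outputs as a pre-allocated buffer filled in place) follows brainpylib's Numba calling style; verify against the brainpylib version you have installed:

import brainpy.math as bm
from jax.core import ShapedArray

def double_eval(x):
  # Abstract evaluation: the output mirrors the input's shape and dtype.
  return ShapedArray(x.shape, x.dtype)

def double_cpu(outs, ins):
  # Must be Numba-compilable: brainpylib JIT-compiles it for the CPU backend.
  x = ins[0]
  outs[:] = x * 2

double_op = bm.register_op('hypothetical_double',
                           eval_shape=double_eval,
                           cpu_func=double_cpu)

y = double_op(bm.ones(4))  # expected: [2., 2., 2., 2.]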