From 520767a27d3b1d0bf789fa7b628ac6a7f7ca46ae Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 2 Oct 2022 23:36:56 +0800 Subject: [PATCH 1/4] update `VariableView` --- brainpy/math/jaxarray.py | 66 ++++++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 196bf8ace..0c4c0a6f7 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -1528,43 +1528,43 @@ def __init__(self, value_or_size, dtype=None, batch_axis: int = None): class VariableView(Variable): """A view of a Variable instance. - This class is used to create a slice view of ``brainpy.math.Variable``. + This class is used to create a subset view of ``brainpy.math.Variable``. + + >>> import brainpy.math as bm + >>> bm.random.seed(123) + >>> origin = bm.Variable(bm.random.random(5)) + >>> view = bm.VariableView(origin, slice(None, 2, None)) # origin[:2] + VariableView([0.02920651, 0.19066381], dtype=float32) ``VariableView`` can be used to update the subset of the original Variable instance, and make operations on this subset of the Variable. + + >>> view[:] = 1. + >>> view + VariableView([1., 1.], dtype=float32) + >>> origin + Variable([1. , 1. , 0.5482849, 0.6564884, 0.8446237], dtype=float32) + >>> view + 10 + DeviceArray([11., 11.], dtype=float32) + >>> view *= 10 + VariableView([10., 10.], dtype=float32) + + The above example demonstrates that the updating of an ``VariableView`` instance + is actually made in the original ``Variable`` instance. + + Moreover, it's worthy to note that ``VariableView`` is not a PyTree. """ def __init__(self, value: Variable, index): self.index = index if not isinstance(value, Variable): raise ValueError('Must be instance of Variable.') - temp_shape = tuple([1] * len(index)) - super(VariableView, self).__init__(jnp.zeros(temp_shape), batch_axis=value.batch_axis) + super(VariableView, self).__init__(value.value, batch_axis=value.batch_axis) self._value = value @property def value(self): return self._value[self.index] - @value.setter - def value(self, value): - int_shape = self.shape - if self.batch_axis is None: - ext_shape = value.shape - else: - ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:] - int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:] - if ext_shape != int_shape: - error = f"The shape of the original data is {int_shape}, while we got {value.shape}" - if self.batch_axis is None: - error += '. Do you forget to set "batch_axis" when initialize this variable?' - else: - error += f' with batch_axis={self.batch_axis}.' - raise MathError(error) - if value.dtype != self._value.dtype: - raise MathError(f"The dtype of the original data is {self._value.dtype}, " - f"while we got {value.dtype}.") - self._value[self.index] = value - def __setitem__(self, index, value): # value is JaxArray if isinstance(value, JaxArray): @@ -1653,3 +1653,23 @@ def fill(self, value): def sort(self, axis=-1, kind=None, order=None): """Sort an array in-place.""" self._value[self.index] = self.value.sort(axis=axis, kind=kind, order=order) + + def update(self, value): + if self.batch_axis is None: + ext_shape = value.shape + int_shape = self.shape + else: + ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:] + int_shape = self.shape[:self.batch_axis] + self.shape[self.batch_axis + 1:] + if ext_shape != int_shape: + error = f"The shape of the original data is {self.shape}, while we got {value.shape}" + if self.batch_axis is None: + error += '. Do you forget to set "batch_axis" when initialize this variable?' + else: + error += f' with batch_axis={self.batch_axis}.' + raise MathError(error) + if value.dtype != self._value.dtype: + raise MathError(f"The dtype of the original data is {self._value.dtype}, " + f"while we got {value.dtype}.") + self._value[self.index] = value.value if isinstance(value, JaxArray) else value + From c743f2bec623130be3567be255bb129890120a54 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 4 Oct 2022 16:49:38 +0800 Subject: [PATCH 2/4] support to set analysis plotting styles --- brainpy/analysis/__init__.py | 4 +- brainpy/analysis/lowdim/lowdim_bifurcation.py | 16 +++-- brainpy/analysis/lowdim/lowdim_phase_plane.py | 22 +++--- brainpy/analysis/plotstyle.py | 72 +++++++++++++++++++ brainpy/analysis/stability.py | 32 +-------- 5 files changed, 96 insertions(+), 50 deletions(-) create mode 100644 brainpy/analysis/plotstyle.py diff --git a/brainpy/analysis/__init__.py b/brainpy/analysis/__init__.py index 48a34d9ca..9ea60642a 100644 --- a/brainpy/analysis/__init__.py +++ b/brainpy/analysis/__init__.py @@ -22,6 +22,4 @@ from .lowdim.lowdim_bifurcation import * from .constants import * -from . import constants as C -from . import stability -from . import utils +from . import constants as C, stability, plotstyle, utils diff --git a/brainpy/analysis/lowdim/lowdim_bifurcation.py b/brainpy/analysis/lowdim/lowdim_bifurcation.py index 3ac4b8487..e8431d1bd 100644 --- a/brainpy/analysis/lowdim/lowdim_bifurcation.py +++ b/brainpy/analysis/lowdim/lowdim_bifurcation.py @@ -8,7 +8,7 @@ import brainpy.math as bm from brainpy import errors -from brainpy.analysis import stability, utils, constants as C +from brainpy.analysis import stability, plotstyle, utils, constants as C from brainpy.analysis.lowdim.lowdim_analyzer import * pyplot = None @@ -79,8 +79,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, pyplot.figure(self.x_var) for fp_type, points in container.items(): if len(points['x']): - plot_style = stability.plot_scheme[fp_type] - pyplot.plot(points['p'], points['x'], '.', **plot_style, label=fp_type) + plot_style = plotstyle.plot_schema[fp_type] + pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type) pyplot.xlabel(self.target_par_names[0]) pyplot.ylabel(self.x_var) @@ -107,10 +107,11 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, ax = fig.add_subplot(projection='3d') for fp_type, points in container.items(): if len(points['x']): - plot_style = stability.plot_scheme[fp_type] + plot_style = plotstyle.plot_schema[fp_type] xs = points['p0'] ys = points['p1'] zs = points['x'] + plot_style.pop('linestyle') ax.scatter(xs, ys, zs, **plot_style, label=fp_type) ax.set_xlabel(self.target_par_names[0]) @@ -298,8 +299,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, pyplot.figure(var) for fp_type, points in container.items(): if len(points['p']): - plot_style = stability.plot_scheme[fp_type] - pyplot.plot(points['p'], points[var], '.', **plot_style, label=fp_type) + plot_style = plotstyle.plot_schema[fp_type] + pyplot.plot(points['p'], points[var], **plot_style, label=fp_type) pyplot.xlabel(self.target_par_names[0]) pyplot.ylabel(var) @@ -330,10 +331,11 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, ax = fig.add_subplot(projection='3d') for fp_type, points in container.items(): if len(points['p0']): - plot_style = stability.plot_scheme[fp_type] + plot_style = plotstyle.plot_schema[fp_type] xs = points['p0'] ys = points['p1'] zs = points[var] + plot_style.pop('linestyle') ax.scatter(xs, ys, zs, **plot_style, label=fp_type) ax.set_xlabel(self.target_par_names[0]) diff --git a/brainpy/analysis/lowdim/lowdim_phase_plane.py b/brainpy/analysis/lowdim/lowdim_phase_plane.py index 693d93f7d..05259b4e6 100644 --- a/brainpy/analysis/lowdim/lowdim_phase_plane.py +++ b/brainpy/analysis/lowdim/lowdim_phase_plane.py @@ -6,7 +6,7 @@ import brainpy.math as bm from brainpy import errors, math -from brainpy.analysis import stability, constants as C, utils +from brainpy.analysis import stability, plotstyle, constants as C, utils from brainpy.analysis.lowdim.lowdim_analyzer import * pyplot = None @@ -107,8 +107,8 @@ def plot_fixed_point(self, show=False, with_plot=True, with_return=False): if with_plot: for fp_type, points in container.items(): if len(points): - plot_style = stability.plot_scheme[fp_type] - pyplot.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type) + plot_style = plotstyle.plot_schema[fp_type] + pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type) pyplot.legend() if show: pyplot.show() @@ -248,9 +248,9 @@ def plot_nullcline(self, with_plot=True, with_return=False, if with_plot: if x_style is None: - x_style = dict(color='cornflowerblue', alpha=.7, ) - fmt = x_style.pop('fmt', '.') - pyplot.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline") + x_style = dict(color='cornflowerblue', alpha=.7, fmt='.') + line_args = (x_style.pop('fmt'), ) if 'fmt' in x_style else tuple() + pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline") # Nullcline of the y variable utils.output('I am computing fy-nullcline ...') @@ -260,9 +260,9 @@ def plot_nullcline(self, with_plot=True, with_return=False, if with_plot: if y_style is None: - y_style = dict(color='lightcoral', alpha=.7, ) - fmt = y_style.pop('fmt', '.') - pyplot.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline") + y_style = dict(color='lightcoral', alpha=.7, fmt='.') + line_args = (y_style.pop('fmt'), ) if 'fmt' in y_style else tuple() + pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline") if with_plot: pyplot.xlabel(self.x_var) @@ -349,8 +349,8 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False, if with_plot: for fp_type, points in container.items(): if len(points['x']): - plot_style = stability.plot_scheme[fp_type] - pyplot.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type) + plot_style = plotstyle.plot_schema[fp_type] + pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type) pyplot.legend() if show: pyplot.show() diff --git a/brainpy/analysis/plotstyle.py b/brainpy/analysis/plotstyle.py new file mode 100644 index 000000000..50c568a15 --- /dev/null +++ b/brainpy/analysis/plotstyle.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + + +__all__ = [ + 'plot_schema', + 'set_plot_schema', +] + +from .stability import (CENTER_MANIFOLD, SADDLE_NODE, STABLE_POINT_1D, + UNSTABLE_POINT_1D, CENTER_2D, STABLE_NODE_2D, + STABLE_FOCUS_2D, STABLE_STAR_2D, STABLE_DEGENERATE_2D, + UNSTABLE_NODE_2D, UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D, + UNSTABLE_DEGENERATE_2D, UNSTABLE_LINE_2D, + STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D, + UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, STABLE_FOCUS_3D, + UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D) + + +_markersize = 20 + +plot_schema = {} + +plot_schema[CENTER_MANIFOLD] = {'color': 'orangered', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'} +plot_schema[SADDLE_NODE] = {"color": 'tab:blue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'} + +plot_schema[STABLE_POINT_1D] = {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'} +plot_schema[UNSTABLE_POINT_1D] = {"color": 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'} + +plot_schema.update({ + CENTER_2D: {'color': 'lime', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_NODE_2D: {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_FOCUS_2D: {"color": 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_STAR_2D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_DEGENERATE_2D: {'color': 'blueviolet', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_NODE_2D: {"color": 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_FOCUS_2D: {"color": 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_STAR_2D: {'color': 'green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_DEGENERATE_2D: {'color': 'springgreen', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_LINE_2D: {'color': 'dodgerblue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, +}) + + +plot_schema.update({ + STABLE_POINT_3D: {'color': 'tab:gray', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_POINT_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_NODE_3D: {'color': 'tab:green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_SADDLE_3D: {'color': 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_FOCUS_3D: {'color': 'tab:pink', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_FOCUS_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_NODE_3D: {'color': 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_CENTER_3D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNKNOWN_3D: {'color': 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, +}) + + +def set_plot_schema(fixed_point: str, **schema): + if not isinstance(fixed_point, str): + raise TypeError(f'Must instance of string, but we got {type(fixed_point)}: {fixed_point}') + if fixed_point not in plot_schema: + raise KeyError(f'Fixed point type {fixed_point} does not found in the built-in types. ') + plot_schema[fixed_point].update(**schema) + + +def set_markersize(markersize): + if not isinstance(markersize, int): + raise TypeError(f"Must be an integer, but got {type(markersize)}: {markersize}") + global _markersize + __markersize = markersize + for key in tuple(plot_schema.keys()): + plot_schema[key]['markersize'] = markersize + + diff --git a/brainpy/analysis/stability.py b/brainpy/analysis/stability.py index 43a5f4ab9..156168098 100644 --- a/brainpy/analysis/stability.py +++ b/brainpy/analysis/stability.py @@ -6,7 +6,7 @@ 'get_1d_stability_types', 'get_2d_stability_types', 'get_3d_stability_types', - 'plot_scheme', + 'stability_analysis', @@ -27,17 +27,13 @@ 'UNSTABLE_LINE_2D', ] -plot_scheme = {} + SADDLE_NODE = 'saddle node' CENTER_MANIFOLD = 'center manifold' -plot_scheme[CENTER_MANIFOLD] = {'color': 'orangered'} -plot_scheme[SADDLE_NODE] = {"color": 'tab:blue'} STABLE_POINT_1D = 'stable point' UNSTABLE_POINT_1D = 'unstable point' -plot_scheme[STABLE_POINT_1D] = {"color": 'tab:red'} -plot_scheme[UNSTABLE_POINT_1D] = {"color": 'tab:olive'} CENTER_2D = 'center' STABLE_NODE_2D = 'stable node' @@ -49,18 +45,7 @@ UNSTABLE_STAR_2D = 'unstable star' UNSTABLE_DEGENERATE_2D = 'unstable degenerate' UNSTABLE_LINE_2D = 'unstable line' -plot_scheme.update({ - CENTER_2D: {'color': 'lime'}, - STABLE_NODE_2D: {"color": 'tab:red'}, - STABLE_FOCUS_2D: {"color": 'tab:purple'}, - STABLE_STAR_2D: {'color': 'tab:olive'}, - STABLE_DEGENERATE_2D: {'color': 'blueviolet'}, - UNSTABLE_NODE_2D: {"color": 'tab:orange'}, - UNSTABLE_FOCUS_2D: {"color": 'tab:cyan'}, - UNSTABLE_STAR_2D: {'color': 'green'}, - UNSTABLE_DEGENERATE_2D: {'color': 'springgreen'}, - UNSTABLE_LINE_2D: {'color': 'dodgerblue'}, -}) + STABLE_POINT_3D = 'unclassified stable point' UNSTABLE_POINT_3D = 'unclassified unstable point' @@ -71,17 +56,6 @@ UNSTABLE_FOCUS_3D = 'unstable focus' UNSTABLE_CENTER_3D = 'unstable center' UNKNOWN_3D = 'unknown 3d' -plot_scheme.update({ - STABLE_POINT_3D: {'color': 'tab:gray'}, - UNSTABLE_POINT_3D: {'color': 'tab:purple'}, - STABLE_NODE_3D: {'color': 'tab:green'}, - UNSTABLE_SADDLE_3D: {'color': 'tab:red'}, - UNSTABLE_FOCUS_3D: {'color': 'tab:pink'}, - STABLE_FOCUS_3D: {'color': 'tab:purple'}, - UNSTABLE_NODE_3D: {'color': 'tab:orange'}, - UNSTABLE_CENTER_3D: {'color': 'tab:olive'}, - UNKNOWN_3D: {'color': 'tab:cyan'}, -}) def get_1d_stability_types(): From 4ca22f6c3ab5819a3220aa3bb065c382e7235699 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 4 Oct 2022 16:55:28 +0800 Subject: [PATCH 3/4] update others --- brainpy/dyn/rates/populations.py | 24 +- docs/quickstart/installation.rst | 2 +- .../overview_of_dynamic_model.ipynb | 295 +++--------------- .../Wang_2002_decision_making_spiking.py | 1 - 4 files changed, 65 insertions(+), 257 deletions(-) diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py index 0c3d90739..e48a0925a 100644 --- a/brainpy/dyn/rates/populations.py +++ b/brainpy/dyn/rates/populations.py @@ -180,8 +180,8 @@ def update(self, tdi, x=None): self.y.value = y def clear_input(self): - self.input[:] = 0. - self.input_y[:] = 0. + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class FeedbackFHN(RateModel): @@ -375,8 +375,8 @@ def update(self, tdi, x=None): self.y.value = y def clear_input(self): - self.input[:] = 0. - self.input_y[:] = 0. + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class QIF(RateModel): @@ -558,8 +558,8 @@ def update(self, tdi, x=None): self.y.value = y def clear_input(self): - self.input[:] = 0. - self.input_y[:] = 0. + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class StuartLandauOscillator(RateModel): @@ -700,8 +700,8 @@ def update(self, tdi, x=None): self.y.value = y def clear_input(self): - self.input[:] = 0. - self.input_y[:] = 0. + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class WilsonCowanModel(RateModel): @@ -857,8 +857,8 @@ def update(self, tdi, x=None): self.y.value = y def clear_input(self): - self.input[:] = 0. - self.input_y[:] = 0. + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class JansenRitModel(RateModel): @@ -976,5 +976,5 @@ def update(self, tdi, x=None): self.i.value = bm.maximum(self.i + di * dt, 0.) def clear_input(self): - self.Ie[:] = 0. - self.Ii[:] = 0. + self.Ie.value = bm.zeros_like(self.Ie) + self.Ii.value = bm.zeros_like(self.Ii) diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index e3db95795..d0aeebcfa 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -159,7 +159,7 @@ Many customized operators in BrainPy are implemented in ``brainpylib``. For GPU operators, you should compile ``brainpylib`` from source. The details please see -`Compile GPU operators in brainpylib <../tutorial_advanced/compile_brainpylib.rst>`_. +`Compile GPU operators in brainpylib <../tutorial_advanced/compile_brainpylib.html>`_. Other Dependency diff --git a/docs/tutorial_building/overview_of_dynamic_model.ipynb b/docs/tutorial_building/overview_of_dynamic_model.ipynb index 9e7c09c4a..9a3f4ecae 100644 --- a/docs/tutorial_building/overview_of_dynamic_model.ipynb +++ b/docs/tutorial_building/overview_of_dynamic_model.ipynb @@ -3,10 +3,7 @@ { "cell_type": "markdown", "metadata": { - "collapsed": true, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": true }, "source": [ "# Utilizing Built-in Models" @@ -14,22 +11,14 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ " @[Tianqiu Zhang](mailto:tianqiuakita@gmail.com) @[Chaoming Wang](mailto:adaduo@outlook.com)" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "BrainPy enables modularity programming and easy model debugging. To build a complex brain dynamics model, you just need to group its building blocks. In this section, we are going to talk about what building blocks we provide, and how to use these building blocks.\n" ] @@ -37,11 +26,7 @@ { "cell_type": "code", "execution_count": 1, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "import brainpy as bp\n", @@ -52,22 +37,14 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Models in ``brainpy.dyn``" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "``brainpy.dyn`` has provided many convenient channels, neurons, synapse, and other models for users. The following figure is a glimpse of the provided models.\n", "\n", @@ -76,33 +53,21 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "New models will be continuously updated in the page of [API documentation](../apis/dyn.rst)." ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Initializing a neuron model" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "All neuron models implemented in brainpy are subclasses of ``brainpy.dyn.NeuGroup``. The initialization of a neuron model just needs to provide the geometry size of neurons in a population group." ] @@ -110,11 +75,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "hh = bp.neurons.HH(size=1) # only 1 neuron\n", @@ -128,11 +89,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Generally speaking, there are two types of arguments can be set by users:\n", "\n", @@ -145,11 +102,7 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -170,11 +123,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "However, neuron models support heterogeneous parameters when performing computations in a neuron group. One can initialize *heterogeneous parameters* by several ways.\n", "\n", @@ -186,11 +135,7 @@ { "cell_type": "code", "execution_count": 4, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -211,11 +156,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "**2\\. Initializer**\n", "\n", @@ -225,11 +166,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -250,11 +187,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "**3\\. Callable function**\n", "\n", @@ -264,11 +197,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -289,11 +218,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Here, let's see how the heterogeneous parameters influence our model simulation." ] @@ -301,11 +226,7 @@ { "cell_type": "code", "execution_count": 7, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# we create 3 neurons in a group. Each neuron has a unique \"gNa\"\n", @@ -316,11 +237,7 @@ { "cell_type": "code", "execution_count": 9, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -358,11 +275,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Similarly, the setting of the initial values of a variable can also be realized through the above three ways: *Array*, *Initializer*, and *Callable function*. For example," ] @@ -370,11 +283,7 @@ { "cell_type": "code", "execution_count": 10, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "hh = bp.neurons.HH(\n", @@ -388,11 +297,7 @@ { "cell_type": "code", "execution_count": 11, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -412,22 +317,14 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Initializing a synapse model" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Initializing a synapse model needs to provide its pre-synaptic group (``pre``), post-synaptic group (``post``) and the connection method between them (``conn``). The below is an example to create an [Exponential synapse model](../apis/auto/dyn/generated/brainpy.dyn.synapses.ExpCUBA.rst):" ] @@ -435,26 +332,18 @@ { "cell_type": "code", "execution_count": 13, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "neu = bp.neurons.LIF(10)\n", "\n", "# here we create a synaptic projection within a population\n", - "syn = bp.synapses.compat.ExpCUBA(pre=neu, post=neu, conn=bp.conn.All2All())" + "syn = bp.synapses.Exponential(pre=neu, post=neu, conn=bp.conn.All2All())" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "BrainPy's build-in synapse models support **heterogeneous** synaptic weights and delay steps by using *Array*, *Initializer* and *Callable function*. For example," ] @@ -462,26 +351,18 @@ { "cell_type": "code", "execution_count": 14, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ - "syn = bp.synapses.compat.ExpCUBA(neu, neu, bp.conn.FixedProb(prob=0.1),\n", - " g_max=bp.init.Uniform(min_val=0.1, max_val=1.),\n", - " delay_step=lambda shape: bm.random.randint(10, 30, shape))" + "syn = bp.synapses.Exponential(neu, neu, bp.conn.FixedProb(prob=0.1),\n", + " g_max=bp.init.Uniform(min_val=0.1, max_val=1.),\n", + " delay_step=lambda shape: bm.random.randint(10, 30, shape))" ] }, { "cell_type": "code", "execution_count": 15, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -503,11 +384,7 @@ { "cell_type": "code", "execution_count": 16, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -526,55 +403,35 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "However, in BrainPy, the built-in synapse models only support homogenous synaptic parameters, like the time constant $\\tau$. Users can [customize their synaptic models](./synapse_models.ipynb) when they want heterogeneous synaptic parameters." ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Similar, the synaptic variables can be initialized heterogeneously by using *Array*, *Initializer*, and *Callable functions*." ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Changing model parameters during simulation" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "In BrainPy, all the dynamically changed variables (no matter it is changed inside or outside the jitted function) should be marked as ``brainpy.math.Variable``. BrainPy's built-in models also support modifying model parameters during simulation." ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "For example, if you want to fix the `gNa` in the first 100 ms simulation, and then try to decrease its value in the following simulations. In this case, we can provide the `gNa` as an instance of ``brainpy.math.Variable`` when initializing the model." ] @@ -582,11 +439,7 @@ { "cell_type": "code", "execution_count": 17, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "hh = bp.neurons.HH(5, gNa=bm.Variable(bm.asarray([120.])))\n", @@ -597,11 +450,7 @@ { "cell_type": "code", "execution_count": 18, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -639,11 +488,7 @@ { "cell_type": "code", "execution_count": 19, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -683,22 +528,14 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Examples of using built-in models" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Here we show users how to simulate a famous neuron models: [The Morris-Lecar neuron model](../apis/auto/dyn/generated/brainpy.dyn.neurons.MorrisLecar.rst), which is a two-dimensional \"reduced\" excitation model applicable to systems having two non-inactivating voltage-sensitive conductances." ] @@ -706,11 +543,7 @@ { "cell_type": "code", "execution_count": 20, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "group = bp.neurons.MorrisLecar(1)" @@ -718,11 +551,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Then users can utilize various tools provided by BrainPy to easily simulate the Morris-Lecar neuron model. Here we are not going to dive into details so please read the corresponding tutorials if you want to learn more." ] @@ -730,11 +559,7 @@ { "cell_type": "code", "execution_count": 21, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -776,11 +601,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Next we will also give users an intuitive understanding about building a network composed of different neurons and synapses model. Users can simply initialize these models as below and pass into ``brainpy.dyn.Network``." ] @@ -788,11 +609,7 @@ { "cell_type": "code", "execution_count": 24, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "neu1 = bp.neurons.HH(1)\n", @@ -803,11 +620,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "By selecting proper runner, users can simulate the network efficiently and plot the simulation results." ] @@ -815,11 +628,7 @@ { "cell_type": "code", "execution_count": 25, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { diff --git a/examples/simulation/Wang_2002_decision_making_spiking.py b/examples/simulation/Wang_2002_decision_making_spiking.py index c600bc1b4..66cfd9a68 100644 --- a/examples/simulation/Wang_2002_decision_making_spiking.py +++ b/examples/simulation/Wang_2002_decision_making_spiking.py @@ -72,7 +72,6 @@ def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, mode=bp.modes.Norm nmda_par = dict(delay_step=int(0.5 / bm.get_dt()), tau_decay=100, tau_rise=2., a=0.5) # E neurons/pyramid neurons - A = bp.neurons.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, tau_ref=2., V_initializer=bp.init.OneInit(-70.), mode=mode) B = bp.neurons.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, From 536c3a8c597f11a7c95d010b78935d9bf9cb989b Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 4 Oct 2022 17:24:36 +0800 Subject: [PATCH 4/4] fix bugs in `LengthDelay` and `VariableView` --- brainpy/math/delayvars.py | 2 +- brainpy/math/jaxarray.py | 20 +++++++++++++++++ brainpy/math/tests/test_jaxarray.py | 33 +++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index 0d525fa96..c709c4804 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -435,7 +435,7 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]): self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step) elif self.update_method == CONCAT_UPDATING: - self.data.value = bm.concatenate([self.data[1:], bm.broadcast_to(value, self.delay_target_shape)], axis=0) + self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])]) else: raise ValueError(f'Unknown updating method "{self.update_method}"') diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 0c4c0a6f7..bb3119569 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -1673,3 +1673,23 @@ def update(self, value): f"while we got {value.dtype}.") self._value[self.index] = value.value if isinstance(value, JaxArray) else value + @value.setter + def value(self, value): + int_shape = self.shape + if self.batch_axis is None: + ext_shape = value.shape + else: + ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:] + int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:] + if ext_shape != int_shape: + error = f"The shape of the original data is {int_shape}, while we got {value.shape}" + if self.batch_axis is None: + error += '. Do you forget to set "batch_axis" when initialize this variable?' + else: + error += f' with batch_axis={self.batch_axis}.' + raise MathError(error) + if value.dtype != self._value.dtype: + raise MathError(f"The dtype of the original data is {self._value.dtype}, " + f"while we got {value.dtype}.") + self._value[self.index] = value.value if isinstance(value, JaxArray) else value + diff --git a/brainpy/math/tests/test_jaxarray.py b/brainpy/math/tests/test_jaxarray.py index 11017b22a..ce000f169 100644 --- a/brainpy/math/tests/test_jaxarray.py +++ b/brainpy/math/tests/test_jaxarray.py @@ -58,3 +58,36 @@ def test_variable_init(self): not bm.array_equal(bm.Variable(bm.random.rand(10)), bm.Variable(10)) ) + + +class TestVariableView(unittest.TestCase): + def test_update(self): + origin = bm.Variable(bm.zeros(10)) + view = bm.VariableView(origin, slice(0, 5, None)) + + view.update(bm.ones(5)) + self.assertTrue( + bm.array_equal(origin, bm.concatenate([bm.ones(5), bm.zeros(5)])) + ) + + view.value = bm.arange(5.) + self.assertTrue( + bm.array_equal(origin, bm.concatenate([bm.arange(5), bm.zeros(5)])) + ) + + view += 10 + self.assertTrue( + bm.array_equal(origin, bm.concatenate([bm.arange(5) + 10, bm.zeros(5)])) + ) + + bm.random.shuffle(view) + print(view) + print(origin) + + view.sort() + self.assertTrue( + bm.array_equal(origin, bm.concatenate([bm.arange(5) + 10, bm.zeros(5)])) + ) + + self.assertTrue(view.sum() == bm.sum(bm.arange(5) + 10)) +