diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 048712f05..5e6cc160c 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -1289,7 +1289,7 @@ def __init__( for _ in range(v.batch_axis - len(self.index) + 1)]))) else: index = self.index - self.slice_vars[k] = bm.VariableRef(v, index) + self.slice_vars[k] = bm.VariableView(v, index) # sub-nodes nodes = target.nodes(method='relative', level=1, include_self=False).subset(DynamicalSystem) diff --git a/brainpy/math/controls.py b/brainpy/math/controls.py index 882b06b67..2587b5a79 100644 --- a/brainpy/math/controls.py +++ b/brainpy/math/controls.py @@ -785,6 +785,10 @@ def _body_fun(op): if not isinstance(static_vals, (tuple, list)): static_vals = (static_vals, ) new_vals = body_fun(*static_vals) + if new_vals is None: + new_vals = tuple() + if not isinstance(new_vals, tuple): + new_vals = (new_vals, ) return [v.value for v in dyn_vars], new_vals def _cond_fun(op): diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index ff84ef933..409c400bd 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -16,7 +16,7 @@ 'Variable', 'TrainVar', 'Parameter', - 'VariableRef', + 'VariableView', ] # Ways to change values in a zero-dimensional array @@ -1494,14 +1494,20 @@ def __init__(self, value, dtype=None, batch_axis: int = None): lambda aux_data, flat_contents: Parameter(*flat_contents)) -class VariableRef(Variable): - """A reference of Variable instance.""" +class VariableView(Variable): + """A view of a Variable instance. + + This class is used to create a slice view of ``brainpy.math.Variable``. + + ``VariableView`` can be used to update the subset of the original + Variable instance, and make operations on this subset of the Variable. + """ def __init__(self, value: Variable, index): self.index = index if not isinstance(value, Variable): raise ValueError('Must be instance of Variable.') temp_shape = tuple([1] * len(index)) - super(VariableRef, self).__init__(jnp.zeros(temp_shape), batch_axis=value.batch_axis) + super(VariableView, self).__init__(jnp.zeros(temp_shape), batch_axis=value.batch_axis) self._value = value @property diff --git a/docs/tutorial_building/build_conductance_neurons.ipynb b/docs/tutorial_building/build_conductance_neurons.ipynb index 232416c49..e19fb751e 100644 --- a/docs/tutorial_building/build_conductance_neurons.ipynb +++ b/docs/tutorial_building/build_conductance_neurons.ipynb @@ -36,9 +36,7 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "outputs": [], + "cell_type": "markdown", "source": [ "On the other hand, simplified models do not care about the physiological features of neurons but mainly focus on how to reproduce the exact spike timing. Therefore, they are more simplified and maybe not biologically explicable.\n", "\n", @@ -47,21 +45,19 @@ "metadata": { "collapsed": false, "pycharm": { - "name": "#%%\n" + "name": "#%% md\n" } } }, { - "cell_type": "code", - "execution_count": null, - "outputs": [], + "cell_type": "markdown", "source": [ "## Building an ion channel" ], "metadata": { "collapsed": false, "pycharm": { - "name": "#%%\n" + "name": "#%% md\n" } } }, @@ -436,17 +432,6 @@ "source": [ "By combining different ion channels, we can get different types of conductance-based neuron models easily and straightforwardly. To see all predifined channel models in BrainPy, please click [here](../apis/dyn.rst)." ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/tutorial_math/control_flows.ipynb b/docs/tutorial_math/control_flows.ipynb index f51c81841..de8719421 100644 --- a/docs/tutorial_math/control_flows.ipynb +++ b/docs/tutorial_math/control_flows.ipynb @@ -55,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "id": "38a2bb50", "metadata": { "pycharm": { @@ -379,13 +379,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "outputs": [ { "data": { "text/plain": "JaxArray([1., 0., 0., 1., 1.], dtype=float32, weak_type=True)" }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -403,13 +403,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "outputs": [ { "data": { - "text/plain": "JaxArray([[0., 0., 1.],\n [1., 1., 0.],\n [0., 0., 0.]], dtype=float32, weak_type=True)" + "text/plain": "JaxArray([[1., 0., 1.],\n [0., 0., 0.],\n [0., 0., 0.]], dtype=float32, weak_type=True)" }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -439,7 +439,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "outputs": [], "source": [ "class OddEvenWhere(bp.Base):\n", @@ -461,13 +461,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "outputs": [ { "data": { "text/plain": "Variable([-1.], dtype=float32)" }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -532,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 13, "outputs": [], "source": [ "class OddEvenCond(bp.Base):\n", @@ -555,13 +555,13 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 14, "outputs": [ { "data": { "text/plain": "Variable([1.], dtype=float32)" }, - "execution_count": 40, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -632,7 +632,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 15, "outputs": [], "source": [ "def f(a):\n", @@ -648,13 +648,13 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 16, "outputs": [ { "data": { "text/plain": "DeviceArray(1., dtype=float32, weak_type=True)" }, - "execution_count": 25, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -671,13 +671,13 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 17, "outputs": [ { "data": { "text/plain": "DeviceArray(2., dtype=float32, weak_type=True)" }, - "execution_count": 26, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -694,13 +694,13 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 18, "outputs": [ { "data": { "text/plain": "DeviceArray(3., dtype=float32, weak_type=True)" }, - "execution_count": 27, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -717,13 +717,13 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 19, "outputs": [ { "data": { "text/plain": "DeviceArray(4., dtype=float32, weak_type=True)" }, - "execution_count": 28, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -740,13 +740,13 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 20, "outputs": [ { "data": { "text/plain": "DeviceArray(5., dtype=float32, weak_type=True)" }, - "execution_count": 29, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -775,7 +775,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 21, "outputs": [], "source": [ "def f2(a, x):\n", @@ -796,13 +796,13 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 22, "outputs": [ { "data": { "text/plain": "DeviceArray(2., dtype=float32, weak_type=True)" }, - "execution_count": 34, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -819,13 +819,13 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 23, "outputs": [ { "data": { "text/plain": "DeviceArray(2., dtype=float32, weak_type=True)" }, - "execution_count": 35, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -842,13 +842,13 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 24, "outputs": [ { "data": { "text/plain": "DeviceArray(0., dtype=float32, weak_type=True)" }, - "execution_count": 36, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -865,13 +865,13 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 25, "outputs": [ { "data": { "text/plain": "DeviceArray(-3., dtype=float32, weak_type=True)" }, - "execution_count": 37, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -888,13 +888,13 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 26, "outputs": [ { "data": { "text/plain": "DeviceArray(5., dtype=float32, weak_type=True)" }, - "execution_count": 38, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -923,7 +923,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 27, "outputs": [ { "name": "stdout", @@ -1008,14 +1008,14 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 28, "outputs": [], "source": [ "class LoopSimple(bp.Base):\n", " def __init__(self):\n", " super(LoopSimple, self).__init__()\n", " rng = bm.random.RandomState(123)\n", - " self.seq = rng.random(1000)\n", + " self.seq = bm.Variable(rng.random(1000))\n", " self.res = bm.Variable(bm.zeros(1))\n", "\n", " def __call__(self):\n", @@ -1032,16 +1032,18 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 29, "outputs": [], "source": [ "import time\n", "\n", - "def measure_time(f):\n", + "def measure_time(f, return_res=False, verbose=True):\n", " t0 = time.time()\n", " r = f()\n", " t1 = time.time()\n", - " print(f'Result: {r}, Time: {t1 - t0}')" + " if verbose:\n", + " print(f'Result: {r}, Time: {t1 - t0}')\n", + " return r if return_res else None" ], "metadata": { "collapsed": false, @@ -1052,13 +1054,13 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 30, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Result: [501.74673], Time: 2.7157142162323\n" + "Result: [501.74664], Time: 0.8628342151641846\n" ] } ], @@ -1077,13 +1079,13 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 31, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Result: [1003.49347], Time: 0.0\n" + "Result: [1003.4931], Time: 0.0\n" ] } ], @@ -1120,8 +1122,8 @@ "\n", "BrainPy also provides its own loop syntax, which is especially suitable for the cases where users are using `brainpy.math.Variable`. Specifically, they are:\n", "\n", - "- [brainpy.math.make_loop](https://brainpy.readthedocs.io/en/latest/apis/auto/math/generated/brainpy.math.controls.make_loop.html)\n", - "- [brainpy.math.make_while](https://brainpy.readthedocs.io/en/latest/apis/auto/math/generated/brainpy.math.controls.make_while.html)\n", + "- [brainpy.math.for_loop](https://brainpy.readthedocs.io/en/latest/apis/auto/math/generated/brainpy.math.controls.for_loop.html)\n", + "- [brainpy.math.while_loop](https://brainpy.readthedocs.io/en/latest/apis/auto/math/generated/brainpy.math.controls.while_loop.html)\n", "\n", "In this section, we only talk about how to use our provided loop functions." ], @@ -1135,7 +1137,7 @@ { "cell_type": "markdown", "source": [ - "### ``brainpy.math.make_loop()``" + "### ``brainpy.math.for_loop()``" ], "metadata": { "collapsed": false, @@ -1173,23 +1175,13 @@ { "cell_type": "markdown", "source": [ - "In BrainPy, you can define this logic using ``brainpy.math.make_loop()``:\n", - "\n", - "```python\n", - "\n", - "loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=False)\n", - "\n", - "hist_of_out_vars = loop_fun(xs)\n", - "```\n", - "\n", - "Or,\n", + "In BrainPy, you can define this logic using ``brainpy.math.for_loop()``:\n", "\n", "```python\n", + "import brainpy.math\n", "\n", - "loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=True)\n", - "\n", - "hist_of_out_vars, hist_of_return_vars = loop_fun(xs)\n", - "```\n" + "hist_of_out_vars = brainpy.math.for_loop(body_fun, dyn_vars, operands)\n", + "```" ], "metadata": { "collapsed": false, @@ -1201,7 +1193,7 @@ { "cell_type": "markdown", "source": [ - "For the above example, we can rewrite it by using ``brainpy.math.make_loop`` as:" + "For the above example, we can rewrite it by using ``brainpy.math.for_loop`` as:" ], "metadata": { "collapsed": false, @@ -1212,7 +1204,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 32, "outputs": [], "source": [ "class LoopStruct(bp.Base):\n", @@ -1222,12 +1214,12 @@ " self.seq = rng.random(1000)\n", " self.res = bm.Variable(bm.zeros(1))\n", "\n", - " def add(s): self.res += s\n", - " self.loop = bm.make_loop(add, dyn_vars=[self.res])\n", - "\n", " def __call__(self):\n", - " self.loop(self.seq)\n", - " return self.res.value" + " def add(s):\n", + " self.res += s\n", + " return self.res.value\n", + "\n", + " return bm.for_loop(body_fun=add, dyn_vars=[self.res], operands=self.seq)" ], "metadata": { "collapsed": false, @@ -1238,21 +1230,22 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 33, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result: [501.74664], Time: 0.028011560440063477\n" - ] + "data": { + "text/plain": "(1000, 1)" + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "model = bm.jit(LoopStruct())\n", "\n", - "# First time will trigger compilation\n", - "measure_time(model)" + "r = measure_time(model, verbose=False, return_res=True)\n", + "r.shape" ], "metadata": { "collapsed": false, @@ -1262,32 +1255,23 @@ } }, { - "cell_type": "code", - "execution_count": 55, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result: [1003.4931], Time: 0.0\n" - ] - } - ], + "cell_type": "markdown", "source": [ - "# Second running\n", - "measure_time(model)" + "In essence, ``body_fun`` defines the one-step updating rule of how variables are updated. All returns of ``body_fun`` will be gathered as the history values.\n", + "``dyn_vars`` defines all dynamical variables used in the ``body_fun``.\n", + "``operands`` specified the inputs of the ``body_fun``. It will be looped over the fist axis." ], "metadata": { "collapsed": false, "pycharm": { - "name": "#%%\n" + "name": "#%% md\n" } } }, { "cell_type": "markdown", "source": [ - "### ``brainpy.math.make_while()``" + "### ``brainpy.math.while_loop()``" ], "metadata": { "collapsed": false, @@ -1299,7 +1283,7 @@ { "cell_type": "markdown", "source": [ - "``brainpy.math.make_while()`` is used to generate a while-loop function when you use ``JaxArray``. It supports the following loop logic:\n", + "``brainpy.math.while_loop()`` is used to generate a while-loop function when you use ``Varible``. It supports the following loop logic:\n", "\n", "```python\n", "\n", @@ -1307,27 +1291,20 @@ " statements\n", "```\n", "\n", - "When using ``brainpy.math.make_while()`` , *condition* should be wrapped as a ``cond_fun`` function which returns a boolean value, and *statements* should be packed as a ``body_fun`` function which does not support returned values:\n", + "When using ``brainpy.math.while_loop()`` , *condition* should be wrapped as a ``cond_fun`` function which returns a boolean value, and *statements* should be packed as a ``body_fun`` function which receives the old values at the latest step and returns the updated values at the current step:\n", "\n", "```python\n", "\n", "while cond_fun(x):\n", - " body_fun(x)\n", + " x = body_fun(x)\n", "```\n", "\n", - "where ``x`` is the external input that is not iterated. All the iterated variables should be marked as ``JaxArray``. All ``JaxArray``s used in ``cond_fun`` and ``body_fun`` should be declared as ``dyn_vars`` variables." - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "Let's look an example:" + "Note the difference between ``brainpy.math.for_loop`` and ``brainpy.math.while_loop``:\n", + "\n", + "1. The returns of ``brainpy.math.for_loop`` are the values to be gathered as the history values. While the returns of ``brainpy.math.while_loop`` should be the same shape and type with its inputs, because they are represented as the updated values.\n", + "2. ``brainpy.math.for_loop`` can receive anything without explicit requirements of returns. But, ``brainpy.math.while_loop`` should return what it receives.\n", + "\n", + "A concreate example of ``brainpy.math.while_loop`` is as the follows:" ], "metadata": { "collapsed": false, @@ -1338,20 +1315,29 @@ }, { "cell_type": "code", - "execution_count": 56, - "outputs": [], + "execution_count": 34, + "outputs": [ + { + "data": { + "text/plain": "()" + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "i = bm.Variable(bm.zeros(1))\n", "counter = bm.Variable(bm.zeros(1))\n", "\n", - "def cond_f(x):\n", + "def cond_f():\n", " return i[0] < 10\n", "\n", - "def body_f(x):\n", + "def body_f():\n", " i.value += 1.\n", " counter.value += i\n", "\n", - "loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])" + "bm.while_loop(body_f, cond_f, dyn_vars=[i, counter], operands=())" ], "metadata": { "collapsed": false, @@ -1374,10 +1360,19 @@ }, { "cell_type": "code", - "execution_count": 57, - "outputs": [], + "execution_count": 35, + "outputs": [ + { + "data": { + "text/plain": "Variable([55.], dtype=float32)" + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "loop()" + "counter" ], "metadata": { "collapsed": false, @@ -1388,19 +1383,19 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 36, "outputs": [ { "data": { - "text/plain": "Variable([55.], dtype=float32)" + "text/plain": "Variable([10.], dtype=float32)" }, - "execution_count": 58, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "counter" + "i" ], "metadata": { "collapsed": false, @@ -1409,21 +1404,42 @@ } } }, + { + "cell_type": "markdown", + "source": [ + "Or, similarly," + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 38, "outputs": [ { "data": { - "text/plain": "Variable([10.], dtype=float32)" + "text/plain": "(DeviceArray(56., dtype=float32),)" }, - "execution_count": 59, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "i" + "i = bm.Variable(bm.zeros(1))\n", + "\n", + "def cond_f(counter):\n", + " return i[0] < 10\n", + "\n", + "def body_f(counter):\n", + " i.value += 1.\n", + " return counter + i[0]\n", + "\n", + "bm.while_loop(body_f, cond_f, dyn_vars=[i], operands=(1., ))" ], "metadata": { "collapsed": false,