Skip to content

Commit

Permalink
Update VariableView and analysis plotting apis (#268)
Browse files Browse the repository at this point in the history
Update `VariableView` and analysis plotting apis
  • Loading branch information
chaoming0625 authored Oct 4, 2022
2 parents 99d5cb9 + 536c3a8 commit 174c81a
Show file tree
Hide file tree
Showing 12 changed files with 258 additions and 331 deletions.
4 changes: 1 addition & 3 deletions brainpy/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,4 @@
from .lowdim.lowdim_bifurcation import *

from .constants import *
from . import constants as C
from . import stability
from . import utils
from . import constants as C, stability, plotstyle, utils
16 changes: 9 additions & 7 deletions brainpy/analysis/lowdim/lowdim_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import brainpy.math as bm
from brainpy import errors
from brainpy.analysis import stability, utils, constants as C
from brainpy.analysis import stability, plotstyle, utils, constants as C
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None
Expand Down Expand Up @@ -79,8 +79,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
pyplot.figure(self.x_var)
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
plot_style = plotstyle.plot_schema[fp_type]
pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(self.x_var)

Expand All @@ -107,10 +107,11 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
plot_style = plotstyle.plot_schema[fp_type]
xs = points['p0']
ys = points['p1']
zs = points['x']
plot_style.pop('linestyle')
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

ax.set_xlabel(self.target_par_names[0])
Expand Down Expand Up @@ -298,8 +299,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
pyplot.figure(var)
for fp_type, points in container.items():
if len(points['p']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
plot_style = plotstyle.plot_schema[fp_type]
pyplot.plot(points['p'], points[var], **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(var)

Expand Down Expand Up @@ -330,10 +331,11 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['p0']):
plot_style = stability.plot_scheme[fp_type]
plot_style = plotstyle.plot_schema[fp_type]
xs = points['p0']
ys = points['p1']
zs = points[var]
plot_style.pop('linestyle')
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

ax.set_xlabel(self.target_par_names[0])
Expand Down
22 changes: 11 additions & 11 deletions brainpy/analysis/lowdim/lowdim_phase_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import brainpy.math as bm
from brainpy import errors, math
from brainpy.analysis import stability, constants as C, utils
from brainpy.analysis import stability, plotstyle, constants as C, utils
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None
Expand Down Expand Up @@ -107,8 +107,8 @@ def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
if with_plot:
for fp_type, points in container.items():
if len(points):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type)
plot_style = plotstyle.plot_schema[fp_type]
pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type)
pyplot.legend()
if show:
pyplot.show()
Expand Down Expand Up @@ -248,9 +248,9 @@ def plot_nullcline(self, with_plot=True, with_return=False,

if with_plot:
if x_style is None:
x_style = dict(color='cornflowerblue', alpha=.7, )
fmt = x_style.pop('fmt', '.')
pyplot.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline")
x_style = dict(color='cornflowerblue', alpha=.7, fmt='.')
line_args = (x_style.pop('fmt'), ) if 'fmt' in x_style else tuple()
pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline")

# Nullcline of the y variable
utils.output('I am computing fy-nullcline ...')
Expand All @@ -260,9 +260,9 @@ def plot_nullcline(self, with_plot=True, with_return=False,

if with_plot:
if y_style is None:
y_style = dict(color='lightcoral', alpha=.7, )
fmt = y_style.pop('fmt', '.')
pyplot.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline")
y_style = dict(color='lightcoral', alpha=.7, fmt='.')
line_args = (y_style.pop('fmt'), ) if 'fmt' in y_style else tuple()
pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline")

if with_plot:
pyplot.xlabel(self.x_var)
Expand Down Expand Up @@ -349,8 +349,8 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False,
if with_plot:
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
plot_style = plotstyle.plot_schema[fp_type]
pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type)
pyplot.legend()
if show:
pyplot.show()
Expand Down
72 changes: 72 additions & 0 deletions brainpy/analysis/plotstyle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-


__all__ = [
'plot_schema',
'set_plot_schema',
]

from .stability import (CENTER_MANIFOLD, SADDLE_NODE, STABLE_POINT_1D,
UNSTABLE_POINT_1D, CENTER_2D, STABLE_NODE_2D,
STABLE_FOCUS_2D, STABLE_STAR_2D, STABLE_DEGENERATE_2D,
UNSTABLE_NODE_2D, UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D,
UNSTABLE_DEGENERATE_2D, UNSTABLE_LINE_2D,
STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D,
UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, STABLE_FOCUS_3D,
UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D)


_markersize = 20

plot_schema = {}

plot_schema[CENTER_MANIFOLD] = {'color': 'orangered', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
plot_schema[SADDLE_NODE] = {"color": 'tab:blue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}

plot_schema[STABLE_POINT_1D] = {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
plot_schema[UNSTABLE_POINT_1D] = {"color": 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}

plot_schema.update({
CENTER_2D: {'color': 'lime', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_NODE_2D: {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_FOCUS_2D: {"color": 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_STAR_2D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_DEGENERATE_2D: {'color': 'blueviolet', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_NODE_2D: {"color": 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_STAR_2D: {'color': 'green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_DEGENERATE_2D: {'color': 'springgreen', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_LINE_2D: {'color': 'dodgerblue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
})


plot_schema.update({
STABLE_POINT_3D: {'color': 'tab:gray', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_POINT_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_NODE_3D: {'color': 'tab:green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_SADDLE_3D: {'color': 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_FOCUS_3D: {'color': 'tab:pink', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_FOCUS_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_NODE_3D: {'color': 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_CENTER_3D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNKNOWN_3D: {'color': 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
})


def set_plot_schema(fixed_point: str, **schema):
if not isinstance(fixed_point, str):
raise TypeError(f'Must instance of string, but we got {type(fixed_point)}: {fixed_point}')
if fixed_point not in plot_schema:
raise KeyError(f'Fixed point type {fixed_point} does not found in the built-in types. ')
plot_schema[fixed_point].update(**schema)


def set_markersize(markersize):
if not isinstance(markersize, int):
raise TypeError(f"Must be an integer, but got {type(markersize)}: {markersize}")
global _markersize
__markersize = markersize
for key in tuple(plot_schema.keys()):
plot_schema[key]['markersize'] = markersize


32 changes: 3 additions & 29 deletions brainpy/analysis/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
'get_1d_stability_types',
'get_2d_stability_types',
'get_3d_stability_types',
'plot_scheme',


'stability_analysis',

Expand All @@ -27,17 +27,13 @@
'UNSTABLE_LINE_2D',
]

plot_scheme = {}


SADDLE_NODE = 'saddle node'
CENTER_MANIFOLD = 'center manifold'
plot_scheme[CENTER_MANIFOLD] = {'color': 'orangered'}
plot_scheme[SADDLE_NODE] = {"color": 'tab:blue'}

STABLE_POINT_1D = 'stable point'
UNSTABLE_POINT_1D = 'unstable point'
plot_scheme[STABLE_POINT_1D] = {"color": 'tab:red'}
plot_scheme[UNSTABLE_POINT_1D] = {"color": 'tab:olive'}

CENTER_2D = 'center'
STABLE_NODE_2D = 'stable node'
Expand All @@ -49,18 +45,7 @@
UNSTABLE_STAR_2D = 'unstable star'
UNSTABLE_DEGENERATE_2D = 'unstable degenerate'
UNSTABLE_LINE_2D = 'unstable line'
plot_scheme.update({
CENTER_2D: {'color': 'lime'},
STABLE_NODE_2D: {"color": 'tab:red'},
STABLE_FOCUS_2D: {"color": 'tab:purple'},
STABLE_STAR_2D: {'color': 'tab:olive'},
STABLE_DEGENERATE_2D: {'color': 'blueviolet'},
UNSTABLE_NODE_2D: {"color": 'tab:orange'},
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan'},
UNSTABLE_STAR_2D: {'color': 'green'},
UNSTABLE_DEGENERATE_2D: {'color': 'springgreen'},
UNSTABLE_LINE_2D: {'color': 'dodgerblue'},
})


STABLE_POINT_3D = 'unclassified stable point'
UNSTABLE_POINT_3D = 'unclassified unstable point'
Expand All @@ -71,17 +56,6 @@
UNSTABLE_FOCUS_3D = 'unstable focus'
UNSTABLE_CENTER_3D = 'unstable center'
UNKNOWN_3D = 'unknown 3d'
plot_scheme.update({
STABLE_POINT_3D: {'color': 'tab:gray'},
UNSTABLE_POINT_3D: {'color': 'tab:purple'},
STABLE_NODE_3D: {'color': 'tab:green'},
UNSTABLE_SADDLE_3D: {'color': 'tab:red'},
UNSTABLE_FOCUS_3D: {'color': 'tab:pink'},
STABLE_FOCUS_3D: {'color': 'tab:purple'},
UNSTABLE_NODE_3D: {'color': 'tab:orange'},
UNSTABLE_CENTER_3D: {'color': 'tab:olive'},
UNKNOWN_3D: {'color': 'tab:cyan'},
})


def get_1d_stability_types():
Expand Down
24 changes: 12 additions & 12 deletions brainpy/dyn/rates/populations.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class FeedbackFHN(RateModel):
Expand Down Expand Up @@ -375,8 +375,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class QIF(RateModel):
Expand Down Expand Up @@ -558,8 +558,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class StuartLandauOscillator(RateModel):
Expand Down Expand Up @@ -700,8 +700,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class WilsonCowanModel(RateModel):
Expand Down Expand Up @@ -857,8 +857,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class JansenRitModel(RateModel):
Expand Down Expand Up @@ -976,5 +976,5 @@ def update(self, tdi, x=None):
self.i.value = bm.maximum(self.i + di * dt, 0.)

def clear_input(self):
self.Ie[:] = 0.
self.Ii[:] = 0.
self.Ie.value = bm.zeros_like(self.Ie)
self.Ii.value = bm.zeros_like(self.Ii)
2 changes: 1 addition & 1 deletion brainpy/math/delayvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)

elif self.update_method == CONCAT_UPDATING:
self.data.value = bm.concatenate([self.data[1:], bm.broadcast_to(value, self.delay_target_shape)], axis=0)
self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])])

else:
raise ValueError(f'Unknown updating method "{self.update_method}"')
Expand Down
Loading

0 comments on commit 174c81a

Please sign in to comment.