fix some bugs (#262)
chaoming0625 authored Sep 28, 2022
2 parents 64d1e21 + c177742 commit e25144e
Showing 17 changed files with 116 additions and 91 deletions.
10 changes: 4 additions & 6 deletions brainpy/analysis/highdim/slow_points.py
@@ -355,8 +355,7 @@ def train(idx):
return loss

def batch_train(start_i, n_batch):
- f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
- return f(bm.arange(start_i, start_i + n_batch))
+ return bm.for_loop(train, dyn_vars, bm.arange(start_i, start_i + n_batch))

# Run the optimization
if self.verbose:
@@ -369,7 +368,7 @@ def batch_train(start_i, n_batch):
break
batch_idx_start = oidx * num_batch
start_time = time.time()
- (_, train_losses) = batch_train(start_i=batch_idx_start, n_batch=num_batch)
+ train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch)
batch_time = time.time() - start_time
opt_losses.append(train_losses)

@@ -722,8 +721,6 @@ def _generate_ds_cell_function(
shared = DotDict(t=t, dt=dt, i=0)

def f_cell(h: Dict):
- target.clear_input()
-
# update target variables
for k, v in self.target_vars.items():
v.value = (bm.asarray(h[k], dtype=v.dtype)
@@ -735,6 +732,7 @@ def f_cell(h: Dict):
v.value = self.excluded_data[k]

# add inputs
+ target.clear_input()
if f_input is not None:
f_input(shared)

@@ -743,7 +741,7 @@ def f_cell(h: Dict):
target.update(*args)

# get new states
- new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis))
+ new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
for k, v in self.target_vars.items()}
return new_h

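The recurring change in this commit: `bm.make_loop(f, dyn_vars=..., has_return=True)` plus a separate call is collapsed into a single `bm.for_loop(body, dyn_vars, operands)`, which scans `body` over the leading axis of `operands` and stacks the returned values, so callers no longer unpack an `(out_vars, returns)` pair. A minimal sketch of the new pattern (names here are illustrative, not from the diff):

```python
import brainpy.math as bm

a = bm.Variable(bm.zeros(1))   # state carried across iterations

def body(x):
  a.value += x                 # in-place update of a tracked Variable
  return a.value               # returns are stacked along the leading axis

# dyn_vars lists every Variable mutated inside body; the operand
# array supplies one x per iteration.
outs = bm.for_loop(body, [a], bm.arange(5.))
print(outs.shape)              # (5, 1): one stacked return per step
```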
51 changes: 41 additions & 10 deletions brainpy/analysis/lowdim/lowdim_bifurcation.py
@@ -354,8 +354,17 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
if with_return:
return final_fps, final_pars, jacobians

- def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
-                             plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
+ def plot_limit_cycle_by_sim(
+     self,
+     duration=100,
+     with_plot: bool = True,
+     with_return: bool = False,
+     plot_style: dict = None,
+     tol: float = 0.001,
+     show: bool = False,
+     dt: float = None,
+     offset: float = 1.
+ ):
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the limit cycle ...')
@@ -400,10 +409,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
if len(ps_limit_cycle[0]):
for i, var in enumerate(self.target_var_names):
pyplot.figure(var)
- pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
-             **plot_style, label='limit cycle (max)')
- pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
-             **plot_style, label='limit cycle (min)')
+ pyplot.plot(ps_limit_cycle[0],
+             ps_limit_cycle[1],
+             vs_limit_cycle[i]['max'],
+             **plot_style,
+             label='limit cycle (max)')
+ pyplot.plot(ps_limit_cycle[0],
+             ps_limit_cycle[1],
+             vs_limit_cycle[i]['min'],
+             **plot_style,
+             label='limit cycle (min)')
pyplot.legend()

elif len(self.target_par_names) == 1:
@@ -427,8 +442,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,


class FastSlow1D(Bifurcation1D):
- def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
-              pars_update=None, resolutions=None, options=None):
+ def __init__(
+     self,
+     model,
+     fast_vars: dict,
+     slow_vars: dict,
+     fixed_vars: dict = None,
+     pars_update: dict = None,
+     resolutions=None,
+     options: dict = None
+ ):
super(FastSlow1D, self).__init__(model=model,
target_pars=slow_vars,
target_vars=fast_vars,
@@ -510,8 +533,16 @@ def plot_trajectory(self, initials, duration, plot_durations=None,


class FastSlow2D(Bifurcation2D):
- def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
-              pars_update=None, resolutions=0.1, options=None):
+ def __init__(
+     self,
+     model,
+     fast_vars: dict,
+     slow_vars: dict,
+     fixed_vars: dict = None,
+     pars_update: dict = None,
+     resolutions=0.1,
+     options: dict = None
+ ):
super(FastSlow2D, self).__init__(model=model,
target_pars=slow_vars,
target_vars=fast_vars,
7 changes: 4 additions & 3 deletions brainpy/analysis/utils/model.py
@@ -112,14 +112,14 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):

# variables
assert isinstance(initial_vars, dict)
- initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=jnp.float_))
+ initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=bm.dftype()))
for k, v in initial_vars.items()}
self.register_implicit_vars(initial_vars)

# parameters
pars = dict() if pars is None else pars
assert isinstance(pars, dict)
- self.pars = [jnp.asarray(bm.as_device_array(v), dtype=jnp.float_)
+ self.pars = [jnp.asarray(bm.as_device_array(v), dtype=bm.dftype())
for k, v in pars.items()]

# integrals
@@ -128,7 +128,8 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
# runner
self.runner = DSRunner(self,
monitors=list(initial_vars.keys()),
- dyn_vars=self.vars().unique(), dt=dt,
+ dyn_vars=self.vars().unique(),
+ dt=dt,
progress_bar=False)

def update(self, sha):
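The dtype switch above is behavioral rather than cosmetic: `jnp.float_` always names float64 (which JAX may silently downcast when x64 is disabled), while `bm.dftype()` reports the float type BrainPy is currently configured to use, so these casts now follow the global precision setting. A small sketch using only the two helpers visible in this diff:

```python
import jax.numpy as jnp
import brainpy.math as bm

v = bm.ones(3)
arr = jnp.asarray(bm.as_device_array(v), dtype=bm.dftype())
print(arr.dtype)  # float32 under the default setup; float64 once x64 is enabled
```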
7 changes: 5 additions & 2 deletions brainpy/dyn/base.py
@@ -343,8 +343,7 @@ def offline_fit(self,
raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.')

def clear_input(self):
- for node in self.nodes(level=1, include_self=False).subset(NeuGroup).unique().values():
-   node.clear_input()
+ pass


class Container(DynamicalSystem):
@@ -430,6 +429,10 @@ def __getattr__(self, item):
else:
return super(Container, self).__getattribute__(item)

+ def clear_input(self):
+   for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
+     node.clear_input()
+

class Sequential(Container):
def __init__(
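Net effect of the two hunks above: `DynamicalSystem.clear_input` becomes a no-op hook, and the recursive walk moves to `Container`, which now visits every child `DynamicalSystem` rather than only `NeuGroup`s, so synapses and sub-networks are cleared as well. A sketch of the resulting behavior (hypothetical minimal network; assumes `NeuGroup.clear_input` zeroes the group's `input` variable, as the removed base-class code implies):

```python
import brainpy as bp

class Net(bp.dyn.Network):       # Network is a Container
  def __init__(self):
    super().__init__()
    self.pop = bp.dyn.LIF(10)

net = Net()
net.pop.input += 1.              # pretend an input step ran
net.clear_input()                # walks level-1 children -> pop.clear_input()
print(net.pop.input.value.sum()) # 0.0
```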
25 changes: 11 additions & 14 deletions brainpy/dyn/neurons/biological_models.py
@@ -244,20 +244,17 @@ def __init__(

# variables
self.V = variable(self._V_initializer, mode, self.varshape)
- if self._m_initializer is None:
-   self.m = bm.Variable(self.m_inf(self.V.value))
- else:
-   self.m = variable(self._m_initializer, mode, self.varshape)
- if self._h_initializer is None:
-   self.h = bm.Variable(self.h_inf(self.V.value))
- else:
-   self.h = variable(self._h_initializer, mode, self.varshape)
- if self._n_initializer is None:
-   self.n = bm.Variable(self.n_inf(self.V.value))
- else:
-   self.n = variable(self._n_initializer, mode, self.varshape)
- self.input = variable(bm.zeros, mode, self.varshape)
+ self.m = (bm.Variable(self.m_inf(self.V.value))
+           if m_initializer is None else
+           variable(self._m_initializer, mode, self.varshape))
+ self.h = (bm.Variable(self.h_inf(self.V.value))
+           if h_initializer is None else
+           variable(self._h_initializer, mode, self.varshape))
+ self.n = (bm.Variable(self.n_inf(self.V.value))
+           if n_initializer is None else
+           variable(self._n_initializer, mode, self.varshape))
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
+ self.input = variable(bm.zeros, mode, self.varshape)

# integral
if self.noise is None:
@@ -309,7 +306,7 @@ def dV(self, V, t, m, h, n, I_ext):

@property
def derivative(self):
- return JointEq([self.dV, self.dm, self.dh, self.dn])
+ return JointEq(self.dV, self.dm, self.dh, self.dn)

def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
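Incidentally, `JointEq` is now built from the derivative functions as positional arguments instead of a list; both spellings construct the same joint system. A toy sketch (an illustrative FitzHugh-Nagumo pair, not from this repo):

```python
import brainpy as bp

dv = lambda v, t, w, I_ext: v - v ** 3 / 3 - w + I_ext
dw = lambda w, t, v: 0.08 * (v + 0.7 - 0.8 * w)

derivative = bp.JointEq(dv, dw)   # previously: bp.JointEq([dv, dw])
integral = bp.odeint(derivative, method='rk4')
```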
8 changes: 3 additions & 5 deletions brainpy/dyn/runners.py
@@ -566,8 +566,7 @@ def f_predict(self, shared_args: Dict = None):

monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)

- def _step_func(inputs):
-   t, i, x = inputs
+ def _step_func(t, i, x):
self.target.clear_input()
# input step
shared = DotDict(t=t, i=i, dt=self.dt)
@@ -586,8 +585,7 @@ def _step_func(inputs):
if self.jit['predict']:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
- f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
- run_func = lambda all_inputs: f(all_inputs)[1]
+ run_func = lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)

else:
def run_func(xs):
@@ -601,7 +599,7 @@ def run_func(xs):
x = tree_map(lambda x: x[i], xs, is_leaf=lambda x: isinstance(x, bm.JaxArray))

# step at the i
- output, mon = _step_func((times[i], indices[i], x))
+ output, mon = _step_func(times[i], indices[i], x)

# append output and monitor
outputs.append(output)
4 changes: 2 additions & 2 deletions brainpy/inputs/currents.py
@@ -307,9 +307,9 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None,

def _f(t):
x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.rand(n)
+ return x.value

- f = bm.make_loop(_f, dyn_vars=[x, rng], out_vars=x)
- noises = f(jnp.arange(t_start, t_end, dt))
+ noises = bm.for_loop(_f, [x, rng], jnp.arange(t_start, t_end, dt))

t_end = duration if t_end is None else t_end
i_start = int(t_start / dt)
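Here `_f` now returns `x.value` and `bm.for_loop` stacks those returns into the noise trace, replacing `make_loop`'s `out_vars=` mechanism. The public API is unchanged; a quick usage sketch against the signature shown above:

```python
import brainpy as bp

# five independent OU processes over 100 ms, noise starting at t = 10 ms
currents = bp.inputs.ou_process(mean=1., sigma=0.2, tau=10.,
                                duration=100., dt=0.1, n=5, t_start=10.)
print(currents.shape)  # (1000, 5): duration/dt steps by n processes
```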
@@ -45,7 +45,7 @@ def f(t):

if show:
fig = plt.figure()
- ax = fig.gca(projection='3d')
+ ax = fig.add_subplot(111, projection='3d')
plt.plot(mon_x, mon_y, mon_z)
ax.set_xlabel('x')
ax.set_xlabel('y')
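This change (and the identical one in `test_sde_scalar.py` below) tracks a Matplotlib API removal: `Figure.gca` no longer accepts keyword arguments such as `projection` (deprecated in 3.4, an error from 3.6), so 3D axes must be requested explicitly. The repeated `ax.set_xlabel` calls in the surrounding context look like a pre-existing typo for `set_ylabel`, which this commit leaves untouched. The modern spelling:

```python
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')  # replaces fig.gca(projection='3d')
ax.plot([0., 1.], [0., 1.], [0., 1.])       # any 3D line, to exercise the axes
plt.show()
```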
31 changes: 14 additions & 17 deletions brainpy/integrators/runner.py
@@ -217,16 +217,12 @@ def __init__(

# build the update step
if self.jit['predict']:
- _loop_func = bm.make_loop(
-   self._step,
-   dyn_vars=self.dyn_vars,
-   out_vars={k: self.variables[k] for k in self.monitors.keys()},
-   has_return=True
- )
+ def _loop_func(times):
+   return bm.for_loop(self._step, self.dyn_vars, times)
else:
def _loop_func(times):
- out_vars = {k: [] for k in self.monitors.keys()}
returns = {k: [] for k in self.fun_monitors.keys()}
+ returns.update({k: [] for k in self.monitors.keys()})
for i in range(len(times)):
_t = times[i]
_dt = self.dt
@@ -237,9 +233,9 @@ def _loop_func(times):
self._step(_t)
# variable monitors
for k in self.monitors.keys():
- out_vars[k].append(bm.as_device_array(self.variables[k]))
- out_vars = {k: bm.asarray(out_vars[k]) for k in self.monitors.keys()}
- return out_vars, returns
+ returns[k].append(bm.as_device_array(self.variables[k]))
+ returns = {k: bm.asarray(returns[k]) for k in returns.keys()}
+ return returns
self.step_func = _loop_func

def _step(self, t):
@@ -252,11 +248,6 @@ def _step(self, t):
kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
self.idx += 1

- # return of function monitors
- returns = dict()
- for key, func in self.fun_monitors.items():
-   returns[key] = func(t, self.dt)
-
# call integrator function
update_values = self.target(**kwargs)
if len(self.target.variables) == 1:
@@ -268,6 +259,13 @@ def _step(self, t):
# progress bar
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())

+ # return of function monitors
+ returns = dict()
+ for key, func in self.fun_monitors.items():
+   returns[key] = func(t, self.dt)
+ for k in self.monitors.keys():
+   returns[k] = self.variables[k].value
+ return returns

def run(self, duration, start_t=None, eval_time=False):
@@ -302,14 +300,13 @@ def run(self, duration, start_t=None, eval_time=False):
refresh=True)
if eval_time:
t0 = time.time()
- hists, returns = self.step_func(times)
+ hists = self.step_func(times)
if eval_time:
running_time = time.time() - t0
if self.progress_bar:
self._pbar.close()

# post-running
- hists.update(returns)
times += self.dt
if self.numpy_mon_after_run:
times = np.asarray(times)
2 changes: 1 addition & 1 deletion brainpy/integrators/sde/tests/test_sde_scalar.py
@@ -50,7 +50,7 @@ def lorenz_system(method, **kwargs):
mon3 = bp.math.array(mon3).to_numpy()

fig = plt.figure()
- ax = fig.gca(projection='3d')
+ ax = fig.add_subplot(111, projection='3d')
plt.plot(mon1, mon2, mon3)
ax.set_xlabel('x')
ax.set_xlabel('y')
6 changes: 4 additions & 2 deletions brainpy/math/operators/utils.py
@@ -17,12 +17,14 @@ def _check_brainpylib(ops_name):
raise PackageMissingError(
f'"{ops_name}" operator need "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n'
f'Please install it through:\n\n'
- f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION} -U'
+ f'>>> pip install brainpylib=={_BRAINPYLIB_MINIMAL_VERSION}\n'
+ f'>>> # or \n'
+ f'>>> pip install brainpylib -U'
)
else:
raise PackageMissingError(
f'"brainpylib" must be installed when the user '
f'wants to use "{ops_name}" operator. \n'
f'Please install "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}" through:\n\n'
- f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}'
+ f'>>> pip install brainpylib'
)
8 changes: 4 additions & 4 deletions brainpy/measure/correlation.py
@@ -17,7 +17,7 @@
]


- @jit
+ # @jit
@partial(vmap, in_axes=(None, 0, 0))
def _cc(states, i, j):
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
@@ -86,7 +86,7 @@ def _var(neu_signal, i):
return jnp.mean(neu_signal * neu_signal) - jnp.mean(neu_signal) ** 2


- @jit
+ # @jit
def voltage_fluctuation(potentials):
r"""Calculate neuronal synchronization via voltage variance.
@@ -202,7 +202,7 @@ def functional_connectivity(activities):
return np.nan_to_num(fc)


- @jit
+ # @jit
def functional_connectivity_dynamics(activities, window_size=30, step_size=5):
"""Computes functional connectivity dynamics (FCD) matrix.
@@ -233,7 +233,7 @@ def _weighted_cov(x, y, w):
return jnp.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / jnp.sum(w)


- @jit
+ # @jit
def weighted_correlation(x, y, w):
"""Weighted Pearson correlation of two data series.
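The four `@jit` decorators are commented out rather than deleted; the commit does not record why, but the `vmap`-based structure underneath is untouched and callers can still jit at the call site. A sketch of that pattern, vmapping a scalar kernel over index arrays to get all pairwise values (toy kernel and data, not the real `_cc` body):

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, vmap

@partial(vmap, in_axes=(None, 0, 0))  # broadcast states, map over index pairs
def _pair_dot(states, i, j):
  return jnp.dot(states[i], states[j])

states = jnp.ones((4, 100))           # 4 signals, 100 time steps
ii, jj = jnp.triu_indices(4, k=1)     # the 6 unordered pairs
vals = jit(_pair_dot)(states, ii, jj) # jit applied by the caller instead
print(vals.shape)                     # (6,)
```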