diff --git a/brainpy/analysis/highdim/slow_points.py b/brainpy/analysis/highdim/slow_points.py
index 9b74be3bf..b99d326d6 100644
--- a/brainpy/analysis/highdim/slow_points.py
+++ b/brainpy/analysis/highdim/slow_points.py
@@ -355,8 +355,7 @@ def train(idx):
       return loss

     def batch_train(start_i, n_batch):
-      f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
-      return f(bm.arange(start_i, start_i + n_batch))
+      return bm.for_loop(train, dyn_vars, bm.arange(start_i, start_i + n_batch))

     # Run the optimization
     if self.verbose:
@@ -369,7 +368,7 @@ def batch_train(start_i, n_batch):
         break
       batch_idx_start = oidx * num_batch
       start_time = time.time()
-      (_, train_losses) = batch_train(start_i=batch_idx_start, n_batch=num_batch)
+      train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch)
       batch_time = time.time() - start_time
       opt_losses.append(train_losses)

@@ -722,8 +721,6 @@ def _generate_ds_cell_function(
     shared = DotDict(t=t, dt=dt, i=0)

     def f_cell(h: Dict):
-      target.clear_input()
-
       # update target variables
       for k, v in self.target_vars.items():
         v.value = (bm.asarray(h[k], dtype=v.dtype)
@@ -735,6 +732,7 @@ def f_cell(h: Dict):
         v.value = self.excluded_data[k]

       # add inputs
+      target.clear_input()
       if f_input is not None:
         f_input(shared)

@@ -743,7 +741,7 @@ def f_cell(h: Dict):
       target.update(*args)

       # get new states
-      new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis))
+      new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
               for k, v in self.target_vars.items()}
       return new_h

diff --git a/brainpy/analysis/lowdim/lowdim_bifurcation.py b/brainpy/analysis/lowdim/lowdim_bifurcation.py
index 58ac84694..3ac4b8487 100644
--- a/brainpy/analysis/lowdim/lowdim_bifurcation.py
+++ b/brainpy/analysis/lowdim/lowdim_bifurcation.py
@@ -354,8 +354,17 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
     if with_return:
       return final_fps, final_pars, jacobians

-  def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
-                              plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
+  def plot_limit_cycle_by_sim(
+      self,
+      duration=100,
+      with_plot: bool = True,
+      with_return: bool = False,
+      plot_style: dict = None,
+      tol: float = 0.001,
+      show: bool = False,
+      dt: float = None,
+      offset: float = 1.
+  ):
     global pyplot
     if pyplot is None: from matplotlib import pyplot
     utils.output('I am plotting the limit cycle ...')
@@ -400,10 +409,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=Fals
     if len(ps_limit_cycle[0]):
       for i, var in enumerate(self.target_var_names):
         pyplot.figure(var)
-        pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
-                    **plot_style, label='limit cycle (max)')
-        pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
-                    **plot_style, label='limit cycle (min)')
+        pyplot.plot(ps_limit_cycle[0],
+                    ps_limit_cycle[1],
+                    vs_limit_cycle[i]['max'],
+                    **plot_style,
+                    label='limit cycle (max)')
+        pyplot.plot(ps_limit_cycle[0],
+                    ps_limit_cycle[1],
+                    vs_limit_cycle[i]['min'],
+                    **plot_style,
+                    label='limit cycle (min)')
         pyplot.legend()

     elif len(self.target_par_names) == 1:
@@ -427,8 +442,16 @@


 class FastSlow1D(Bifurcation1D):
-  def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
-               pars_update=None, resolutions=None, options=None):
+  def __init__(
+      self,
+      model,
+      fast_vars: dict,
+      slow_vars: dict,
+      fixed_vars: dict = None,
+      pars_update: dict = None,
+      resolutions=None,
+      options: dict = None
+  ):
     super(FastSlow1D, self).__init__(model=model,
                                      target_pars=slow_vars,
                                      target_vars=fast_vars,
@@ -510,8 +533,16 @@ def plot_trajectory(self, initials, duration, plot_durations=None,


 class FastSlow2D(Bifurcation2D):
-  def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
-               pars_update=None, resolutions=0.1, options=None):
+  def __init__(
+      self,
+      model,
+      fast_vars: dict,
+      slow_vars: dict,
+      fixed_vars: dict = None,
+      pars_update: dict = None,
+      resolutions=0.1,
+      options: dict = None
+  ):
     super(FastSlow2D, self).__init__(model=model,
                                      target_pars=slow_vars,
                                      target_vars=fast_vars,
diff --git a/brainpy/analysis/utils/model.py b/brainpy/analysis/utils/model.py
index 2d0ab6835..c70905d45 100644
--- a/brainpy/analysis/utils/model.py
+++ b/brainpy/analysis/utils/model.py
@@ -112,14 +112,14 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):

     # variables
     assert isinstance(initial_vars, dict)
-    initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=jnp.float_))
+    initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=bm.dftype()))
                     for k, v in initial_vars.items()}
     self.register_implicit_vars(initial_vars)

     # parameters
     pars = dict() if pars is None else pars
     assert isinstance(pars, dict)
-    self.pars = [jnp.asarray(bm.as_device_array(v), dtype=jnp.float_)
+    self.pars = [jnp.asarray(bm.as_device_array(v), dtype=bm.dftype())
                  for k, v in pars.items()]

     # integrals
@@ -128,7 +128,8 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
     # runner
     self.runner = DSRunner(self,
                            monitors=list(initial_vars.keys()),
-                           dyn_vars=self.vars().unique(), dt=dt,
+                           dyn_vars=self.vars().unique(),
+                           dt=dt,
                            progress_bar=False)

   def update(self, sha):
diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py
index 5e6cc160c..27e31cab5 100644
--- a/brainpy/dyn/base.py
+++ b/brainpy/dyn/base.py
@@ -343,8 +343,7 @@ def offline_fit(self,
     raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.')

   def clear_input(self):
-    for node in self.nodes(level=1, include_self=False).subset(NeuGroup).unique().values():
-      node.clear_input()
+    pass


 class Container(DynamicalSystem):
@@ -430,6 +429,10 @@ def __getattr__(self, item):
     else:
       return super(Container, self).__getattribute__(item)

+  def clear_input(self):
+    for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
+      node.clear_input()
+

 class Sequential(Container):
   def __init__(
diff --git a/brainpy/dyn/neurons/biological_models.py b/brainpy/dyn/neurons/biological_models.py
index 86c7a7b2e..a691e54f7 100644
--- a/brainpy/dyn/neurons/biological_models.py
+++ b/brainpy/dyn/neurons/biological_models.py
@@ -244,20 +244,17 @@ def __init__(

     # variables
     self.V = variable(self._V_initializer, mode, self.varshape)
-    if self._m_initializer is None:
-      self.m = bm.Variable(self.m_inf(self.V.value))
-    else:
-      self.m = variable(self._m_initializer, mode, self.varshape)
-    if self._h_initializer is None:
-      self.h = bm.Variable(self.h_inf(self.V.value))
-    else:
-      self.h = variable(self._h_initializer, mode, self.varshape)
-    if self._n_initializer is None:
-      self.n = bm.Variable(self.n_inf(self.V.value))
-    else:
-      self.n = variable(self._n_initializer, mode, self.varshape)
-    self.input = variable(bm.zeros, mode, self.varshape)
+    self.m = (bm.Variable(self.m_inf(self.V.value))
+              if m_initializer is None else
+              variable(self._m_initializer, mode, self.varshape))
+    self.h = (bm.Variable(self.h_inf(self.V.value))
+              if h_initializer is None else
+              variable(self._h_initializer, mode, self.varshape))
+    self.n = (bm.Variable(self.n_inf(self.V.value))
+              if n_initializer is None else
+              variable(self._n_initializer, mode, self.varshape))
     self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
+    self.input = variable(bm.zeros, mode, self.varshape)

     # integral
     if self.noise is None:
@@ -309,7 +306,7 @@ def dV(self, V, t, m, h, n, I_ext):

   @property
   def derivative(self):
-    return JointEq([self.dV, self.dm, self.dh, self.dn])
+    return JointEq(self.dV, self.dm, self.dh, self.dn)

   def update(self, tdi, x=None):
     t, dt = tdi['t'], tdi['dt']
diff --git a/brainpy/dyn/runners.py b/brainpy/dyn/runners.py
index fbb1813cf..b9c81fde3 100644
--- a/brainpy/dyn/runners.py
+++ b/brainpy/dyn/runners.py
@@ -566,8 +566,7 @@ def f_predict(self, shared_args: Dict = None):

      monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)

-      def _step_func(inputs):
-        t, i, x = inputs
+      def _step_func(t, i, x):
        self.target.clear_input()
        # input step
        shared = DotDict(t=t, i=i, dt=self.dt)
@@ -586,8 +585,7 @@ def _step_func(inputs):
      if self.jit['predict']:
        dyn_vars = self.target.vars()
        dyn_vars.update(self.dyn_vars)
-        f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
-        run_func = lambda all_inputs: f(all_inputs)[1]
+        run_func = lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)

      else:
        def run_func(xs):
@@ -601,7 +599,7 @@
            x = tree_map(lambda x: x[i], xs, is_leaf=lambda x: isinstance(x, bm.JaxArray))

            # step at the i
-            output, mon = _step_func((times[i], indices[i], x))
+            output, mon = _step_func(times[i], indices[i], x)

            # append output and monitor
            outputs.append(output)
diff --git a/brainpy/inputs/currents.py b/brainpy/inputs/currents.py
index b6ada2712..3d0eb9cfe 100644
--- a/brainpy/inputs/currents.py
+++ b/brainpy/inputs/currents.py
@@ -307,9 +307,9 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None,

  def _f(t):
    x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.rand(n)
+    return x.value

-  f = bm.make_loop(_f, dyn_vars=[x, rng], out_vars=x)
-  noises = f(jnp.arange(t_start, t_end, dt))
+  noises = bm.for_loop(_f, [x, rng], jnp.arange(t_start, t_end, dt))

  t_end = duration if t_end is None else t_end
  i_start = int(t_start / dt)
diff --git a/brainpy/integrators/ode/tests/test_ode_method_adaptive_rk.py b/brainpy/integrators/ode/tests/test_ode_method_adaptive_rk.py
index 671f9df85..4de18c2f3 100644
--- a/brainpy/integrators/ode/tests/test_ode_method_adaptive_rk.py
+++ b/brainpy/integrators/ode/tests/test_ode_method_adaptive_rk.py
@@ -45,7 +45,7 @@ def f(t):

  if show:
    fig = plt.figure()
-    ax = fig.gca(projection='3d')
+    ax = fig.add_subplot(111, projection='3d')
    plt.plot(mon_x, mon_y, mon_z)
    ax.set_xlabel('x')
    ax.set_xlabel('y')
diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py
index ea79b2d65..b59faf0c8 100644
--- a/brainpy/integrators/runner.py
+++ b/brainpy/integrators/runner.py
@@ -217,16 +217,12 @@ def __init__(
    # build the update step
    if self.jit['predict']:
-      _loop_func = bm.make_loop(
-        self._step,
-        dyn_vars=self.dyn_vars,
-        out_vars={k: self.variables[k] for k in self.monitors.keys()},
-        has_return=True
-      )
+      def _loop_func(times):
+        return bm.for_loop(self._step, self.dyn_vars, times)

    else:
      def _loop_func(times):
-        out_vars = {k: [] for k in self.monitors.keys()}
        returns = {k: [] for k in self.fun_monitors.keys()}
+        returns.update({k: [] for k in self.monitors.keys()})
        for i in range(len(times)):
          _t = times[i]
          _dt = self.dt
@@ -237,9 +233,9 @@
          self._step(_t)
          # variable monitors
          for k in self.monitors.keys():
-            out_vars[k].append(bm.as_device_array(self.variables[k]))
-        out_vars = {k: bm.asarray(out_vars[k]) for k in self.monitors.keys()}
-        return out_vars, returns
+            returns[k].append(bm.as_device_array(self.variables[k]))
+        returns = {k: bm.asarray(returns[k]) for k in returns.keys()}
+        return returns
    self.step_func = _loop_func

  def _step(self, t):
@@ -252,11 +248,6 @@
      kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
      self.idx += 1

-    # return of function monitors
-    returns = dict()
-    for key, func in self.fun_monitors.items():
-      returns[key] = func(t, self.dt)
-
    # call integrator function
    update_values = self.target(**kwargs)
    if len(self.target.variables) == 1:
@@ -268,6 +259,13 @@
    # progress bar
    if self.progress_bar:
      id_tap(lambda *args: self._pbar.update(), ())
+
+    # return of function monitors
+    returns = dict()
+    for key, func in self.fun_monitors.items():
+      returns[key] = func(t, self.dt)
+    for k in self.monitors.keys():
+      returns[k] = self.variables[k].value
    return returns

  def run(self, duration, start_t=None, eval_time=False):
@@ -302,14 +300,13 @@
                         refresh=True)
    if eval_time:
      t0 = time.time()
-    hists, returns = self.step_func(times)
+    hists = self.step_func(times)
    if eval_time:
      running_time = time.time() - t0
    if self.progress_bar:
      self._pbar.close()

    # post-running
-    hists.update(returns)
    times += self.dt
    if self.numpy_mon_after_run:
      times = np.asarray(times)
diff --git a/brainpy/integrators/sde/tests/test_sde_scalar.py b/brainpy/integrators/sde/tests/test_sde_scalar.py
index 19465cfb4..6f9fae51a 100644
--- a/brainpy/integrators/sde/tests/test_sde_scalar.py
+++ b/brainpy/integrators/sde/tests/test_sde_scalar.py
@@ -50,7 +50,7 @@ def lorenz_system(method, **kwargs):
  mon3 = bp.math.array(mon3).to_numpy()

  fig = plt.figure()
-  ax = fig.gca(projection='3d')
+  ax = fig.add_subplot(111, projection='3d')
  plt.plot(mon1, mon2, mon3)
  ax.set_xlabel('x')
  ax.set_xlabel('y')
diff --git a/brainpy/math/operators/utils.py b/brainpy/math/operators/utils.py
index bae41f232..fb8c1cfa4 100644
--- a/brainpy/math/operators/utils.py
+++ b/brainpy/math/operators/utils.py
@@ -17,12 +17,14 @@ def _check_brainpylib(ops_name):
    raise PackageMissingError(
      f'"{ops_name}" operator need "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n'
      f'Please install it through:\n\n'
-      f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION} -U'
+      f'>>> pip install brainpylib=={_BRAINPYLIB_MINIMAL_VERSION}\n'
+      f'>>> # or \n'
+      f'>>> pip install brainpylib -U'
    )
  else:
    raise PackageMissingError(
      f'"brainpylib" must be installed when the user '
      f'wants to use "{ops_name}" operator. \n'
      f'Please install "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}" through:\n\n'
-      f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}'
+      f'>>> pip install brainpylib'
    )
diff --git a/brainpy/measure/correlation.py b/brainpy/measure/correlation.py
index 5e4840860..c742fa05b 100644
--- a/brainpy/measure/correlation.py
+++ b/brainpy/measure/correlation.py
@@ -17,7 +17,7 @@
 ]


-@jit
+# @jit
 @partial(vmap, in_axes=(None, 0, 0))
 def _cc(states, i, j):
  sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
@@ -86,7 +86,7 @@ def _var(neu_signal, i):
  return jnp.mean(neu_signal * neu_signal) - jnp.mean(neu_signal) ** 2


-@jit
+# @jit
 def voltage_fluctuation(potentials):
  r"""Calculate neuronal synchronization via voltage variance.

@@ -202,7 +202,7 @@ def functional_connectivity(activities):
  return np.nan_to_num(fc)


-@jit
+# @jit
 def functional_connectivity_dynamics(activities, window_size=30, step_size=5):
  """Computes functional connectivity dynamics (FCD) matrix.

@@ -233,7 +233,7 @@ def _weighted_cov(x, y, w):
  return jnp.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / jnp.sum(w)


-@jit
+# @jit
 def weighted_correlation(x, y, w):
  """Weighted Pearson correlation of two data series.

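Note on the recurring migration in this patch: `bm.make_loop(body, dyn_vars=..., has_return=True)` built a loop function that returned an `(out_vars, returns)` pair, whereas `bm.for_loop(body, dyn_vars, operands)` takes the loop body, the variables it mutates, and the per-step operands in a single call and returns only the stacked per-step values. A minimal sketch of the new convention, inferred from the call sites above; the accumulator example itself is hypothetical, not code from the repository:

import brainpy.math as bm

xs = bm.arange(5.)              # one operand slice is passed to body per step
acc = bm.Variable(bm.zeros(1))  # loop-carried state lives in a Variable

def body(x):
  acc.value = acc + x           # in-place update, tracked through dyn_vars
  return acc.value              # per-step returns are stacked along axis 0

# before: f = bm.make_loop(body, dyn_vars=[acc], has_return=True)
#         _, history = f(xs)
history = bm.for_loop(body, [acc], xs)  # after: one call, stacked returns only

This is why the `(_, train_losses) = ...` unpacking and the `f(all_inputs)[1]` indexing in the old code collapse to plain assignments throughout the patch.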
diff --git a/brainpy/measure/firings.py b/brainpy/measure/firings.py
index e01c3f499..0f335e8b9 100644
--- a/brainpy/measure/firings.py
+++ b/brainpy/measure/firings.py
@@ -36,7 +36,7 @@ def raster_plot(sp_matrix, times):
  return index, time


-@jit
+# @jit
 def _firing_rate(sp_matrix, window):
  sp_matrix = bm.as_device_array(sp_matrix)
  rate = jnp.sum(sp_matrix, axis=1) / sp_matrix.shape[1]
diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py
index 2657619d7..ceb99bf19 100644
--- a/brainpy/train/back_propagation.py
+++ b/brainpy/train/back_propagation.py
@@ -520,7 +520,7 @@ def f_train(self, shared_args=None) -> Callable:
    shared_args_str = serialize_kwargs(shared_args)
    if shared_args_str not in self._f_train_compiled:

-      def train_step(x):
+      def train_step(*x):
        # t, i, input_, target_ = x
        res = self.f_grad(shared_args)(*x)
        self.optimizer.update(res[0])
@@ -529,8 +529,7 @@ def train_step(x):
      if self.jit[c.FIT_PHASE]:
        dyn_vars = self.target.vars()
        dyn_vars.update(self.dyn_vars)
-        f = bm.make_loop(train_step, dyn_vars=dyn_vars.unique(), has_return=True)
-        run_func = lambda all_inputs: f(all_inputs)[1]
+        run_func = lambda all_inputs: bm.for_loop(train_step, dyn_vars.unique(), all_inputs)

      else:
        def run_func(xs):
@@ -541,7 +540,7 @@
          x = tree_map(lambda x: x[i], inputs, is_leaf=_is_jax_array)
          y = tree_map(lambda x: x[i], targets, is_leaf=_is_jax_array)
          # step at the i
-          loss = train_step((times[i], indices[i], x, y))
+          loss = train_step(times[i], indices[i], x, y)
          # append output and monitor
          losses.append(loss)
        return bm.asarray(losses)
diff --git a/brainpy/train/online.py b/brainpy/train/online.py
index f0efbed76..4045511cb 100644
--- a/brainpy/train/online.py
+++ b/brainpy/train/online.py
@@ -234,8 +234,7 @@ def _make_fit_func(self, shared_args: Dict):

    monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)

-    def _step_func(all_inputs):
-      t, i, x, ys = all_inputs
+    def _step_func(t, i, x, ys):
      shared = DotDict(t=t, dt=self.dt, i=i)

      # input step
@@ -262,8 +261,7 @@ def _step_func(all_inputs):
    if self.jit['fit']:
      dyn_vars = self.target.vars()
      dyn_vars.update(self.dyn_vars)
-      f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
-      return lambda all_inputs: f(all_inputs)[1]
+      return lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)

    else:
      def run_func(all_inputs):
@@ -273,7 +271,7 @@
        for i in range(times.shape[0]):
          x = tree_map(lambda x: x[i], xs)
          y = tree_map(lambda x: x[i], ys)
-          output, mon = _step_func((times[i], indices[i], x, y))
+          output, mon = _step_func(times[i], indices[i], x, y)
          outputs.append(output)
          for key, value in mon.items():
            monitors[key].append(value)
diff --git a/examples/analysis/4d_HH_model.py b/examples/analysis/4d_HH_model.py
index c4c4720d4..fc7b54af4 100644
--- a/examples/analysis/4d_HH_model.py
+++ b/examples/analysis/4d_HH_model.py
@@ -3,13 +3,15 @@
 import brainpy as bp
 import brainpy.math as bm
+
 I = 5.
 model = bp.dyn.neurons.HH(1)
-runner = bp.dyn.DSRunner(model, inputs=('input', I), monitors=['V'])
+runner = bp.dyn.DSRunner(model, inputs=(model.input, I), monitors=['V'])
 runner.run(100)
 bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True)

 # analysis
+bm.enable_x64()
 model = bp.dyn.neurons.HH(1, method='euler')
 finder = bp.analysis.SlowPointFinder(
   model,
   target_vars={'V': model.V,
@@ -18,13 +20,14 @@
                'm': model.m,
                'h': model.h,
                'n': model.n},
-  dt=1.
+  dt=10.
 )
-candidates = {'V': bm.random.normal(0., 5., (1000, model.num)) - 50.,
+finder.find_fps_with_opt_solver(
+  candidates={'V': bm.random.normal(0., 10., (1000, model.num)) - 50.,
              'm': bm.random.random((1000, model.num)),
              'h': bm.random.random((1000, model.num)),
              'n': bm.random.random((1000, model.num))}
-finder.find_fps_with_opt_solver(candidates=candidates)
+)
 finder.filter_loss(1e-7)
 finder.keep_unique(tolerance=1e-1)
 print('fixed_points: ', finder.fixed_points)
diff --git a/examples/training/Song_2016_EI_RNN.py b/examples/training/Song_2016_EI_RNN.py
index 1c27883ce..1a273ea8c 100644
--- a/examples/training/Song_2016_EI_RNN.py
+++ b/examples/training/Song_2016_EI_RNN.py
@@ -128,17 +128,17 @@ def __init__(self, num_input, num_hidden, num_output, num_batch,
    self.mask = bm.asarray(mask, dtype=bm.dftype())

    # input weight
-    self.w_ir = bm.TrainVar(w_ir(num_input, num_hidden))
+    self.w_ir = bm.TrainVar(w_ir((num_input, num_hidden)))

    # recurrent weight
    bound = 1 / num_hidden ** 0.5
-    self.w_rr = bm.TrainVar(w_rr(num_hidden, num_hidden))
+    self.w_rr = bm.TrainVar(w_rr((num_hidden, num_hidden)))
    self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size)
    self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))

    # readout weight
    bound = 1 / self.e_size ** 0.5
-    self.w_ro = bm.TrainVar(w_ro(self.e_size, num_output))
+    self.w_ro = bm.TrainVar(w_ro((self.e_size, num_output)))
    self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))

    # variables
@@ -158,15 +158,13 @@ def make_update(self, h: bm.JaxArray, o: bm.JaxArray):
    def f(x):
      h.value = self.cell(x, h.value)
      o.value = self.readout(h.value[:, :self.e_size])
+      return h.value, o.value

    return f

  def predict(self, xs):
    self.h[:] = 0.
-    f = bm.make_loop(self.make_update(self.h, self.o),
-                     dyn_vars=self.vars(),
-                     out_vars=[self.h, self.o])
-    return f(xs)
+    return bm.for_loop(self.make_update(self.h, self.o), self.vars(), xs)

  def loss(self, xs, ys):
    hs, os = self.predict(xs)
@@ -247,7 +245,7 @@ def train(xs, ys):
    rnn_activity, action_pred = predict(inputs)

    # Compute performance
-    action_pred = action_pred.numpy()
+    action_pred = bm.as_numpy(action_pred)
    choice = np.argmax(action_pred[-1, 0, :])
    correct = choice == gt[-1]

@@ -257,7 +255,7 @@ def train(xs, ys):
    trial_infos[i] = trial_info

    # Log stimulus period activity
-    rnn_activity = rnn_activity.numpy()[:, 0, :]
+    rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
    activity_dict[i] = rnn_activity

 # Compute stimulus selectivity for all units
@@ -312,7 +310,7 @@ def train(xs, ys):
 plt.show()

 # %%
-W = (bm.abs(net.w_rr) * net.mask).numpy()
+W = bm.as_numpy(bm.abs(net.w_rr) * net.mask)
 # Sort by selectivity
 W = W[:, ind_sort][ind_sort, :]
 wlim = np.max(np.abs(W))
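Note on the clear_input rework in brainpy/dyn/base.py: `DynamicalSystem.clear_input` becomes a no-op, and the recursion into children moves to `Container.clear_input`, which now visits every child `DynamicalSystem` rather than only `NeuGroup` instances; runners and the analysis `f_cell` call it once per step just before applying inputs. A hedged sketch of what this implies for a user-defined group that accumulates external currents (`LeakyGroup` is a hypothetical illustration, not code from this patch):

import brainpy as bp
import brainpy.math as bm

class LeakyGroup(bp.dyn.NeuGroup):
  def __init__(self, size):
    super(LeakyGroup, self).__init__(size=size)
    self.V = bm.Variable(bm.zeros(self.num))
    self.input = bm.Variable(bm.zeros(self.num))  # accumulated external current

  def update(self, tdi, x=None):
    # integrate whatever current was accumulated into self.input this step
    self.V.value = self.V + (self.input - self.V) / 10. * tdi['dt']

  def clear_input(self):
    # the base class no longer resets inputs; each group does it explicitly
    self.input[:] = 0.

Inside a network container, `Container.clear_input` reaches this method automatically at every simulation step, so the accumulated current is zeroed before the next round of inputs arrives.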