Skip to content

Commit

Permalink
stuart landau model C++ implementation added
Browse files Browse the repository at this point in the history
  • Loading branch information
Ziaeemehr committed Dec 22, 2023
1 parent 41ede78 commit a566774
Show file tree
Hide file tree
Showing 14 changed files with 458 additions and 41 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
*.pdf
__pycache__/
build/
_build/
.vscode/
vbi.egg-info/
html/
*/CPPModels/*.py
*/cpp/_src*.py
.ipynb_checkpoints/
latex/
html/
Expand Down
2 changes: 1 addition & 1 deletion examples/damp_oscillator_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ode = DO_cpp(parameters)
print(ode())

sol = ode.simulate()
sol = ode.run()
t = sol["t"]
x = sol["x"]

Expand Down
33 changes: 33 additions & 0 deletions examples/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from scipy import signal
from numpy.fft import fft
import matplotlib.pyplot as plt
from copy import deepcopy, copy


def fft_signal(x, t):
Expand Down Expand Up @@ -56,3 +57,35 @@ def plot_ts(data, par, ax, **kwargs):
ax[i].tick_params(labelsize=14)

plt.tight_layout()

def sl_visualize(data, params, **kwargs):

x = data['x']
t = data['t']

mosaic = """
AAB
"""

fs = 1/(params['dt']*params['record_step'])
fig = plt.figure(constrained_layout=True, figsize=(12, 3))
ax = fig.subplot_mosaic(mosaic)

x_avg = np.mean(x, axis=0)
ax['A'].plot(t, x_avg.T, label="x", **kwargs)
ax['A'].set_ylabel(r"$\sum$ Real $Z$", fontsize=16)

freq, pxx = signal.welch(x, fs=fs, nperseg=4096)
# pxx /= np.max(pxx)
pxx_avg = np.average(pxx, axis=0)
ax['B'].plot(freq, pxx_avg, **kwargs)
ax['B'].set_xlabel("Frequency [Hz]", fontsize=16)
ax['B'].set_ylabel("Power", fontsize=16)
ax['B'].set_xlim(0, 60)
ti = params['t_transition']
tf = params['t_end']
ax['A'].set_xlim(tf-2, tf)
ax['A'].set_xlabel("Time [s]", fontsize=16)

idx = np.argmax(pxx_avg)
print(f"fmax = {freq[idx]} Hz, Pxx = {pxx_avg[idx]}")
2 changes: 1 addition & 1 deletion examples/jansen_rit_sdde_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@


obj = JR_sdde_cpp(param)
data = obj.simulate()
data = obj.run()
t = data['t']
x = data['x']
sti = data['sti']
Expand Down
2 changes: 1 addition & 1 deletion examples/jansen_rit_sde_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

obj = JR_sde_cpp(par_dict)
print(obj())
data = obj.simulate(control_dict)
data = obj.run(control_dict)

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
plot_ts_pxx(data, par_dict, [ax[0], ax[1]], alpha=0.6, lw=1)
Expand Down
21 changes: 19 additions & 2 deletions examples/mpr_sde_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

obj = MPR_sde(parameters)
# print(obj())
sol = obj.simulate(par=control_dict)
sol = obj.run(par=control_dict)
print(obj.eta)

t = sol["t"]
Expand All @@ -53,10 +53,27 @@
print(f"x.shape = {x.shape}")

if x.ndim == 2:
pass
fig, ax = plt.subplots(1, figsize=(10, 3))
ax.set_xlabel("Time [s]")
ax.set_ylabel("BOLD")
plt.plot(t/1000, x.T, alpha=0.8, lw=2)
plt.margins(0,0.1)
plt.tight_layout()
plt.show()
plt.savefig("output/mpr_sde_ts.png", dpi=300)
plt.close()
else:
exit(0)


# Feature extraction ------------------------------------------------
from vbi.feature_extraction.features_settings import *
from vbi.feature_extraction.calc_features import *

fs = 1/(parameters["dt_bold"]) / 1000
cfg = get_features_by_domain(domain="statistical")
# report_cfg(cfg)
data = dataframe_feature_extractor([x], fs, fea_dict=cfg, n_workers=1)
print(data.values)


42 changes: 42 additions & 0 deletions examples/sl_sdde_ts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import vbi
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from helpers import sl_visualize
from vbi.models.cpp.sl import SL_sdde

seed = 2
np.random.seed(seed)

LABESSIZE = 14
plt.rcParams['axes.labelsize'] = LABESSIZE
plt.rcParams['xtick.labelsize'] = LABESSIZE
plt.rcParams['ytick.labelsize'] = LABESSIZE

ds = vbi.Load_Data_Sample(nn=84)
weights = ds.get_weights(normalize=True)
distances = ds.get_lengths() / 1000 # [m]
nn = weights.shape[0]

params = {
"G": 1000.0,
"a": -5.0,
"dt": 1e-4,
"weights": weights,
"distances": distances, # [m]
"velocity": 6.0, # [m/s]
"omega": 40*2*np.pi * np.ones(nn),
'sigma_r': 1e-4, # noise strength
'sigma_v': 1e-4, # noise strength
"record_step": 2,
"t_initial": 0.0,
"t_transition": 2.0,
"t_end": 10.0,
"seed": 2,
}

obj = SL_sdde(params)
data = obj.run()

sl_visualize(data, params, color="k")
plt.show()
4 changes: 3 additions & 1 deletion vbi/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .tests.all_tests import tests
from .tests.all_tests import tests

from .feature_extraction.utility import Load_Data_Sample
Loading

0 comments on commit a566774

Please sign in to comment.