Skip to content

Commit

Permalink
add script for bug #64
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Mar 12, 2024
1 parent dc650ca commit fd8f990
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions examples/bug-64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import collections
import networkx as nx
import matplotlib.pyplot as plt
import jax.numpy as jnp
import vbjax as vb

KMTheta = collections.namedtuple(typename="KMTheta", field_names="G omega".split(" "))
km_default_theta = KMTheta(G=0.05, omega=1.0)
KMState = collections.namedtuple(typename="KMState", field_names="x".split(" "))

def km_dfun(x, c, p: KMTheta):
"Kuramoto model"
dx = p.omega + jnp.vdot(p.G, c) # or just p.G * c
return dx

def network(x, p):
weights, node_params = p
c = jnp.sum(weights * jnp.sin(x - x[:, None]), axis=1)
dx = km_dfun(x, c, node_params)
return dx


def get_ts(params, dt=0.1, T=50.0, G=0.0, sigma=0.1):
'''Run the Kuramoto model'''
omega, weights, par = params
nn = weights.shape[0]
G = jnp.ones(nn) * G
_, loop = vb.make_sde(dt, dfun=network, gfun=sigma)
par = par._replace(G=G, omega=omega)
nt = int(T / dt)
zs = vb.rand(nt, nn) * 2 * jnp.pi
xs = loop(zs[0], zs[1:], (weights, par))
ts = jnp.linspace(0, nt * dt, len(xs))
return xs, ts


nn = 3
weights = nx.to_numpy_array(nx.complete_graph(nn))
dt = 0.1

omega = jnp.abs(vb.randn(nn) * 1.0)
print('omega values are', omega)

# plt.figure(figsize=(10, 3))
for i, sigma in enumerate([0.0, 0.1, 0.2]):
xs, ts = get_ts((omega, weights, km_default_theta), dt=dt, G=0.9, sigma=sigma)
#plt.subplot(1, 3, i + 1)
#plt.plot(ts[:-1], jnp.sin(xs))
print(i, 'sigma=', sigma, jnp.sum(jnp.abs(jnp.diff(xs,axis=0))) )
# plt.show()

0 comments on commit fd8f990

Please sign in to comment.