6
6
import numpy as np
7
7
import scipy
8
8
from escnn .group import Representation
9
+ from tqdm import tqdm
9
10
10
11
from data .DynamicsRecording import DynamicsRecording
11
12
from utils .mysc import companion_matrix , matrix_average_trick , random_orthogonal_matrix
@@ -26,6 +27,7 @@ def sample_initial_condition(state_dim, P=None, z=None):
26
27
distance_from_origin = max (distance_from_origin , MIN_DISTANCE_FROM_ORIGIN ) # Truncate unlikely low values
27
28
x0 = distance_from_origin * direction
28
29
30
+ trials = 500
29
31
if P is not None :
30
32
violation = P @ x0 < z
31
33
is_constraint_violated = np .any (violation )
@@ -42,6 +44,10 @@ def sample_initial_condition(state_dim, P=None, z=None):
42
44
43
45
violation = P @ x0 < z
44
46
is_constraint_violated = np .any (violation )
47
+
48
+ trials -= 1
49
+ if trials == 0 :
50
+ raise RuntimeError ("Too constrained." )
45
51
if np .linalg .norm (x0 ) < MIN_DISTANCE_FROM_ORIGIN : # If sample is too close to zero ignore it.
46
52
x0 = sample_initial_condition (state_dim , P = P , z = z )
47
53
return x0
@@ -152,7 +158,7 @@ def stable_equivariant_lin_dynamics(rep_X: Representation, time_constant=1, min_
152
158
iso_state_dim = rep_iso .size
153
159
A_iso = stable_lin_dynamics (rep_iso ,
154
160
time_constant = time_constant ,
155
- stable_eigval_prob = 1 / (iso_state_dim ) if state_dim > 1 else 0.0 ,
161
+ stable_eigval_prob = 1 / (iso_state_dim + 1 ) if state_dim > 1 else 0.0 ,
156
162
min_period = min_period ,
157
163
max_period = max_period )
158
164
# Enforce G-equivariance
@@ -247,7 +253,7 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
247
253
if __name__ == '__main__' :
248
254
np .set_printoptions (precision = 3 )
249
255
250
- order = 3
256
+ order = 2
251
257
subgroups_ids = dict (C2 = ('cone' , 1 ),
252
258
Tetrahedral = ('fulltetra' ,),
253
259
Octahedral = (True , 'octa' ,),
@@ -264,7 +270,8 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
264
270
G , g_dynamics_2_Gsub_domain , g_domain_2_g_dynamics = G_domain .subgroup (G_id )
265
271
266
272
# Define the state representation.
267
- rep_X = G .standard_representation () # + G.irrep(1)
273
+ # rep_X = G.regular_representation # + G.irrep(1)
274
+ rep_X = G .irrep (0 ) + G .standard_representation () # + G.irrep(1)
268
275
# rep_X = G.irrep(1) + G.irrep(2) #+ G.irrep(1) #+ G.irrep(0)
269
276
#
270
277
# Generate stable equivariant linear dynamics withing a range of fast and slow dynamics
@@ -279,11 +286,11 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
279
286
T = fastest_period # Simulate until the slowest stable mode has completed a full period.
280
287
else : # System has transient dynamics that vanish to 36.8% in fastest_time_constant seconds.
281
288
T = 6 * fastest_time_constant # Required time for this transient dynamics to vanish.
282
- dt = T * 0.005 # Sample time to obtain 200 samples per trajectory
289
+ dt = T * 0.005 # Sample time to obtain 100 samples per trajectory
283
290
284
291
# Generate trajectories of the system dynamics
285
292
n_constraints = 0
286
- n_trajs = 100
293
+ n_trajs = 120
287
294
# Generate hyperplanes that constraint outer region of space
288
295
P_symm , offset = None , None
289
296
if n_constraints > 0 :
@@ -293,12 +300,12 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
293
300
for normal_plane in normal_planes :
294
301
normal_orbit = np .vstack ([np .linalg .det (rep_X (g )) * (rep_X (g ) @ normal_plane ) for g in G .elements ])
295
302
# Fix point of linear systems is the origin
296
- offset_orbit = np .asarray ([- np .random .uniform (- 0.05 , 0.6 )] * normal_orbit .shape [0 ])
303
+ offset_orbit = np .asarray ([- np .random .uniform (- 0.1 , 0.3 )] * normal_orbit .shape [0 ])
297
304
P_symm = np .vstack ((P_symm , normal_orbit )) if P_symm is not None else normal_orbit
298
305
offset = np .concatenate ((offset , offset_orbit )) if offset is not None else offset_orbit
299
306
300
307
trajs_per_noise_level = []
301
- for noise_level in range (10 ):
308
+ for noise_level in tqdm ( range (10 ), desc = "noise level" ):
302
309
sigma = T * 0.005 * noise_level
303
310
state_trajs = []
304
311
for _ in range (n_trajs ):
@@ -380,9 +387,12 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
380
387
fig = fig , constraint_matrix = P_symm , constraint_offset = offset ,
381
388
traj_colorscale = 'Agsunset' , init_state_color = 'yellow' ,
382
389
legendgroup = "val" )
390
+ else :
391
+ pass
383
392
384
- fig .write_html (path_2_system / 'test_trajectories.html' )
385
- if noise_level == 0 and fig is not None :
393
+ if fig is not None :
394
+ fig .write_html (path_2_system / 'test_trajectories.html' )
395
+ if noise_level == 1 and fig is not None :
386
396
fig .show ()
387
397
# fig.show()
388
398
print (f"Recordings saved to { path_2_system } " )
0 commit comments