Skip to content

Commit 610ad1f

Browse files
committed
Equiv-DAE operational
This commit restructures the DPNET and DAE algorithms under the same base class LatendMarkovDynamics module. Both DPNet and DAE have their own equivariant implementations operational in this commit.
1 parent 2ce6c44 commit 610ad1f

19 files changed

+892
-462
lines changed

cfg/config.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ hydra:
6767
- system.pred_horizon
6868
- system.eval_pred_horizon
6969
- system.data_path
70+
- system.group
71+
- system.noise_level
7072
- debug
7173
- debug_loops
7274
- seed

cfg/model/base_model.yaml

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
name: ??
44
# Dataset related parameters
55
augment: False # Weather to use data-augmentation of the dataset if the system has symmetries.
6-
# Architectural and Task parameters
7-
equivariant: True # Impose equivariance constraints on NNs
6+
# Model hyperparameters
87
activation: ReLU
9-
num_layers: 4 # Number MLPs' layers (including input and output layers)
10-
bias: True # Use bias in the MLPs
11-
batch_norm: True # Use batch normalization layers before activation functions in MLPs
8+
num_layers: 5 # Number MLPs' layers (including input and output layers)
9+
num_hidden_units: 128 # Number of hidden units in each layer
10+
batch_norm: True
11+
bias: False
1212

1313
# Optimization hyperparameters parameters
1414
lr: 1e-3

cfg/model/dae.yaml

+5-6
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@ defaults:
33

44
name: DAE
55
# Model hyperparameters
6-
loss_pred_w: 0.01 # Cost function weight for prediction in observation space Z
7-
equivariant: True
8-
activation: ReLU
9-
n_layers: 4 # Number MLPs' layers (including input and output layers)
6+
obs_pred_w: 1.0 # Cost function weight for prediction in observation space Z
7+
orth_w: 1.0 # Weight of the orthonormal regularization term in the loss function
8+
corr_w: 0.0
109

1110
# Optimization hyperparameters parameters
1211
lr: 1e-3
1312
batch_size: 1024
13+
equivariant: False
1414

15-
#eigval_init: "stable" # Initialization of eigenvalues [stable, ]
16-
#eigval_constraint: "unconstrained" # [unconstrained or unit_circle]
15+
summary: ${model.name}-Obs_w:${model.obs_pred_w}-Orth_w:${model.orth_w}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}

cfg/model/edae.yaml

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
defaults:
2+
- dae
3+
4+
name: E-DAE
5+
# Model hyperparameters
6+
obs_pred_w: 1.0 # Cost function weight for prediction in observation space Z
7+
orth_w: 1.0 # Weight of the orthonormal regularization term in the loss function
8+
corr_w: 0.0
9+
state_dependent_obs_dyn: False # Whether to use state-dependent observation dynamics
10+
group_avg_trick: True
11+
12+
equivariant: True
13+

cfg/system/linear_system.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ obs_state_dim: 2 # Dimension of the system's observable state
88

99
frames_per_state: 1 # Number of time-frames to use as a Markov Process state time step
1010
pred_horizon: 10 # Number (or percentage) of Markov Process state time steps to predict into the future
11-
eval_pred_horizon: 100 # Number (or percentage) of Markov Process state time steps to predict into the future
11+
eval_pred_horizon: .5 # Number (or percentage) of Markov Process state time steps to predict into the future
1212

1313
group: SO(2)
14-
noise_level: 0
14+
noise_level: 1
1515
n_constraints: 0
1616

17-
data_path: '${system.name}/group=${system.group}-dim=${system.state_dim}/'
17+
data_path: '${system.name}/group=${system.group}-dim=${system.state_dim}/n_constraints=${system.n_constraints}'
1818

1919
summary: S:${system.state_dim}-OS:${system.obs_state_dim}-H:${system.pred_horizon}-G:${system.group}-N:${system.noise_level}

dynamical_systems/stable_linear_system.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import scipy
88
from escnn.group import Representation
9+
from tqdm import tqdm
910

1011
from data.DynamicsRecording import DynamicsRecording
1112
from utils.mysc import companion_matrix, matrix_average_trick, random_orthogonal_matrix
@@ -26,6 +27,7 @@ def sample_initial_condition(state_dim, P=None, z=None):
2627
distance_from_origin = max(distance_from_origin, MIN_DISTANCE_FROM_ORIGIN) # Truncate unlikely low values
2728
x0 = distance_from_origin * direction
2829

30+
trials = 500
2931
if P is not None:
3032
violation = P @ x0 < z
3133
is_constraint_violated = np.any(violation)
@@ -42,6 +44,10 @@ def sample_initial_condition(state_dim, P=None, z=None):
4244

4345
violation = P @ x0 < z
4446
is_constraint_violated = np.any(violation)
47+
48+
trials -= 1
49+
if trials == 0:
50+
raise RuntimeError("Too constrained.")
4551
if np.linalg.norm(x0) < MIN_DISTANCE_FROM_ORIGIN: # If sample is too close to zero ignore it.
4652
x0 = sample_initial_condition(state_dim, P=P, z=z)
4753
return x0
@@ -152,7 +158,7 @@ def stable_equivariant_lin_dynamics(rep_X: Representation, time_constant=1, min_
152158
iso_state_dim = rep_iso.size
153159
A_iso = stable_lin_dynamics(rep_iso,
154160
time_constant=time_constant,
155-
stable_eigval_prob=1 / (iso_state_dim) if state_dim > 1 else 0.0,
161+
stable_eigval_prob=1 / (iso_state_dim + 1) if state_dim > 1 else 0.0,
156162
min_period=min_period,
157163
max_period=max_period)
158164
# Enforce G-equivariance
@@ -247,7 +253,7 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
247253
if __name__ == '__main__':
248254
np.set_printoptions(precision=3)
249255

250-
order = 3
256+
order = 2
251257
subgroups_ids = dict(C2=('cone', 1),
252258
Tetrahedral=('fulltetra',),
253259
Octahedral=(True, 'octa',),
@@ -264,7 +270,8 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
264270
G, g_dynamics_2_Gsub_domain, g_domain_2_g_dynamics = G_domain.subgroup(G_id)
265271

266272
# Define the state representation.
267-
rep_X = G.standard_representation() # + G.irrep(1)
273+
# rep_X = G.regular_representation # + G.irrep(1)
274+
rep_X = G.irrep(0) + G.standard_representation() # + G.irrep(1)
268275
# rep_X = G.irrep(1) + G.irrep(2) #+ G.irrep(1) #+ G.irrep(0)
269276
#
270277
# Generate stable equivariant linear dynamics withing a range of fast and slow dynamics
@@ -279,11 +286,11 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
279286
T = fastest_period # Simulate until the slowest stable mode has completed a full period.
280287
else: # System has transient dynamics that vanish to 36.8% in fastest_time_constant seconds.
281288
T = 6 * fastest_time_constant # Required time for this transient dynamics to vanish.
282-
dt = T * 0.005 # Sample time to obtain 200 samples per trajectory
289+
dt = T * 0.005 # Sample time to obtain 100 samples per trajectory
283290

284291
# Generate trajectories of the system dynamics
285292
n_constraints = 0
286-
n_trajs = 100
293+
n_trajs = 120
287294
# Generate hyperplanes that constraint outer region of space
288295
P_symm, offset = None, None
289296
if n_constraints > 0:
@@ -293,12 +300,12 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
293300
for normal_plane in normal_planes:
294301
normal_orbit = np.vstack([np.linalg.det(rep_X(g)) * (rep_X(g) @ normal_plane) for g in G.elements])
295302
# Fix point of linear systems is the origin
296-
offset_orbit = np.asarray([-np.random.uniform(-0.05, 0.6)] * normal_orbit.shape[0])
303+
offset_orbit = np.asarray([-np.random.uniform(-0.1, 0.3)] * normal_orbit.shape[0])
297304
P_symm = np.vstack((P_symm, normal_orbit)) if P_symm is not None else normal_orbit
298305
offset = np.concatenate((offset, offset_orbit)) if offset is not None else offset_orbit
299306

300307
trajs_per_noise_level = []
301-
for noise_level in range(10):
308+
for noise_level in tqdm(range(10), desc="noise level"):
302309
sigma = T * 0.005 * noise_level
303310
state_trajs = []
304311
for _ in range(n_trajs):
@@ -380,9 +387,12 @@ def evolve_linear_dynamics(A: np.ndarray, init_state: np.ndarray, dt: float, sim
380387
fig=fig, constraint_matrix=P_symm, constraint_offset=offset,
381388
traj_colorscale='Agsunset', init_state_color='yellow',
382389
legendgroup="val")
390+
else:
391+
pass
383392

384-
fig.write_html(path_2_system / 'test_trajectories.html')
385-
if noise_level == 0 and fig is not None:
393+
if fig is not None:
394+
fig.write_html(path_2_system / 'test_trajectories.html')
395+
if noise_level == 1 and fig is not None:
386396
fig.show()
387397
# fig.show()
388398
print(f"Recordings saved to {path_2_system}")

nn/DPNet.py nn/DeepProjections.py

+25-8
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def __init__(
8989
dt=dt,
9090
**markov_dyn_params)
9191

92-
9392
def forecast(self, state: Tensor, n_steps: int = 1, **kwargs) -> [dict[str, Tensor]]:
9493
"""Forward pass of the dynamics model, producing a prediction of the next `n_steps` states.
9594
@@ -107,15 +106,33 @@ def forecast(self, state: Tensor, n_steps: int = 1, **kwargs) -> [dict[str, Tens
107106
obs_state = self.obs_fn(state)
108107
pred_obs_state_traj = self.obs_state_dynamics.forcast(state=obs_state, n_steps=n_steps)
109108
pred_state_traj = self.inv_obs_fn(pred_obs_state_traj)
110-
if self.transfer_op is None:
111-
raise RuntimeError("The transfer operator not approximated yet. Call `approximate_transfer_operator`")
112109

113110
assert pred_state_traj.shape == (self._batch_size, time_horizon, self.state_dim), \
114111
f"{pred_state_traj.shape}!=({self._batch_size}, {time_horizon}, {self.state_dim})"
115112
assert pred_obs_state_traj.shape == (self._batch_size, time_horizon, self.obs_state_dim), \
116113
f"{pred_obs_state_traj.shape}!=({self._batch_size}, {time_horizon}, {self.obs_state_dim})"
114+
raise NotImplementedError("This function needs to handle pre/post state processing")
117115
return pred_state_traj, pred_obs_state_traj
118116

117+
def pre_process_obs_state(self,
118+
obs_state_traj: Tensor,
119+
obs_state_traj_aux: Optional[Tensor] = None) -> dict[str, Tensor]:
120+
""" Apply transformations to the observable state trajectory.
121+
Args:
122+
obs_state_traj: (batch * time, obs_state_dim) or (batch, time, obs_state_dim)
123+
obs_state_traj_aux: (batch * time, obs_state_dim) or (batch, time, obs_state_dim)
124+
Returns:
125+
Directory containing
126+
- obs_state_traj: (batch, time, obs_state_dim) tensor.
127+
- obs_state_traj_aux: (batch, time, obs_state_dim) tensor.
128+
"""
129+
obs_state_traj = super().pre_process_obs_state(obs_state_traj)['obs_state_traj']
130+
if obs_state_traj_aux is not None:
131+
obs_state_traj_aux = super().pre_process_obs_state(obs_state_traj_aux)['obs_state_traj']
132+
else:
133+
obs_state_traj_aux = None
134+
return dict(obs_state_traj=obs_state_traj, obs_state_traj_aux=obs_state_traj_aux)
135+
119136
def compute_loss_and_metrics(self,
120137
obs_state_traj: Tensor,
121138
obs_state_traj_aux: Tensor,
@@ -277,7 +294,7 @@ def approximate_transfer_operator(self, train_data_loader: DataLoader):
277294
# Obtain the observable state
278295
obs_fn_output = self.obs_fn(state_traj)
279296
# Post process observable state
280-
obs_state_trajs = self.post_process_obs_state(*obs_fn_output)
297+
obs_state_trajs = self.pre_process_obs_state(*obs_fn_output)
281298
obs_state_traj = obs_state_trajs.pop('obs_state_traj')
282299

283300
assert obs_state_traj.shape[1] == 2, f"Expected single step datapoints, got {obs_state_traj.shape[1]} steps."
@@ -320,7 +337,7 @@ def build_obs_fn(self, num_layers, identity=False, **kwargs):
320337

321338
return TwinMLP(net_kwargs=obs_fn_params, backbone_kwargs=backbone_params, fake_aux_fn=not self.aux_obs_space)
322339

323-
def build_inv_obs_fn(self, num_layers, linear_decoder: bool, identity=False, **kwargs):
340+
def build_inv_obs_fn(self, num_layers, linear_decoder: bool, identity=False, **kwargs):
324341
if identity:
325342
return lambda x: x
326343

@@ -385,7 +402,7 @@ def get_hparams(self):
385402
state_dim = 2
386403
obs_state_dim = state_dim
387404

388-
change_of_basis = None# Tensor(random_orthogonal_matrix(state_dim))
405+
change_of_basis = None # Tensor(random_orthogonal_matrix(state_dim))
389406

390407
test_dpnet = DPNet(state_dim=state_dim,
391408
obs_state_dim=obs_state_dim,
@@ -398,5 +415,5 @@ def get_hparams(self):
398415
pred_state_traj = out['pred_state_traj']
399416

400417
assert pred_state_traj.shape == random_state_traj.shape, f"{pred_state_traj.shape} != {random_state_traj.shape}"
401-
assert torch.allclose(pred_state_traj, random_state_traj, rtol=1e-5, atol=1e-5), f"{pred_state_traj - random_state_traj}"
402-
418+
assert torch.allclose(pred_state_traj, random_state_traj, rtol=1e-5,
419+
atol=1e-5), f"{pred_state_traj - random_state_traj}"

0 commit comments

Comments
 (0)