Skip to content

Commit 2ce6c44

Browse files
committed
Intro LatentMarkovDynamics
1 parent 5f66aff commit 2ce6c44

20 files changed

+945
-825
lines changed

cfg/model/dpnet.yaml

+6-7
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@ defaults:
22
- base_model
33

44
name: DPNet
5-
# Model hyperparameters
65

76
# Symmetry exploitation parameters
87
equivariant: False
9-
group_avg_trick: True
108

11-
activation: ELU
12-
num_layers: 4 # Number MLPs' layers (including input and output layers)
13-
num_hidden_units: 128 # Number of hidden units in each layer
14-
batch_norm: ${model.equivariant} # Something wrong happens when we turn batch norm on. Performance goes to hell.
9+
# Model hyperparameters
10+
activation: ReLU
11+
num_layers: 5 # Number of MLP layers (including input and output layers)
12+
num_hidden_units: 128 # Number of hidden units in each layer
13+
batch_norm: True
1514
bias: False
1615
# Optimization hyperparameters
1716
lr: 1e-3
@@ -23,5 +22,5 @@ orth_w: 0.5 # Weight of the orthonormal regu
2322
aux_obs_space: False # Whether to use an auxiliary observable space.
2423
use_spectral_score: True # Whether to use the spectral or the correlation score
2524

26-
summary: ${model.name}-Equiv:${model.equivariant}-CK_w:${model.ck_w}-Orth_w:${model.orth_w}-Win:${model.max_ck_window_length}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}-AOS:${model.aux_obs_space}-SS:${model.use_spectral_score}
25+
summary: ${model.name}-CK_w:${model.ck_w}-Orth_w:${model.orth_w}-Win:${model.max_ck_window_length}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}-AOS:${model.aux_obs_space}-SS:${model.use_spectral_score}
2726

cfg/model/edpnet.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
defaults:
2+
- dpnet
3+
4+
name: E-DPNet
5+
6+
# Symmetry exploitation parameters
7+
equivariant: True
8+
group_avg_trick: True
9+
10+
activation: ELU

cfg/system/linear_system.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@ defaults:
33

44
name: 'linear_system'
55

6-
state_dim: 3 # Dimension of the system's state
7-
obs_state_dim: 3 # Dimension of the system's observable state
6+
state_dim: 2 # Dimension of the system's state
7+
obs_state_dim: 2 # Dimension of the system's observable state
88

99
frames_per_state: 1 # Number of time-frames to use as a Markov Process state time step
1010
pred_horizon: 10 # Number (or percentage) of Markov Process state time steps to predict into the future
1111
eval_pred_horizon: 100 # Number (or percentage) of Markov Process state time steps to predict into the future
1212

13-
group: C3
13+
group: SO(2)
1414
noise_level: 0
1515
n_constraints: 0
1616

nn/DPNet.py

+155-284
Large diffs are not rendered by default.

nn/EquivDPNet.py

+142-181
Large diffs are not rendered by default.

nn/EquivDynamicsAutoencoder.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import torch
77
from escnn.nn import FieldType, GeometricTensor
88

9-
from nn.LinearDynamics import EquivariantLinearDynamics
10-
from nn.markov_dynamics import MarkovDynamicsModule
9+
from nn.LinearDynamics import LinearDynamics
10+
from nn.markov_dynamics import MarkovDynamics
1111
from nn.mlp import MLP
1212
from nn.emlp import EMLP
1313
from utils.representation_theory import isotypic_basis
@@ -39,7 +39,7 @@ def compute_invariant_features(x: torch.Tensor, field_type: FieldType) -> torch.
3939
return inv_features
4040

4141

42-
class EquivDynamicsAutoEncoder(MarkovDynamicsModule):
42+
class EquivDynamicsAutoEncoder(MarkovDynamics):
4343
TIME_DIM = 1
4444

4545
def __init__(self,
@@ -74,12 +74,12 @@ def __init__(self,
7474
num_hidden_units=num_encoder_hidden_neurons,
7575
num_layers=num_encoder_layers,
7676
activation=activation,
77-
with_bias=True)
77+
bias=True)
7878
self.decoder = EMLP(in_type=self.obs_state_type,
7979
out_type=self.state_type,
8080
num_hidden_units=num_encoder_hidden_neurons,
8181
num_layers=num_encoder_layers,
82-
with_bias=True)
82+
bias=True)
8383
# Define the linear dynamics module.
8484
self.obs_state_dynamics = EquivariantLinearDynamics(in_type=self.obs_state_type,
8585
dt=self.dt,
@@ -100,7 +100,7 @@ def __init__(self,
100100
num_hidden_units=num_encoder_hidden_neurons,
101101
num_layers=3,
102102
activation=torch.nn.ReLU,
103-
with_bias=True)
103+
bias=True)
104104
raise NotImplementedError("TODO: Need to implement this. "
105105
"There is no easy way to get batched parametrization of equivariant maps")
106106

nn/EquivLinearDynamics.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from collections import OrderedDict
2+
from typing import Optional, Union
3+
4+
import numpy as np
5+
import torch
6+
from escnn.group import Representation
7+
from torch import Tensor
8+
9+
from nn.LinearDynamics import DmdSolver, LinearDynamics
10+
from nn.markov_dynamics import MarkovDynamics
11+
from utils.mysc import full_rank_lstsq, full_rank_lstsq_symmetric
12+
from utils.representation_theory import isotypic_basis
13+
14+
15+
class EquivLinearDynamics(LinearDynamics):
    """Linear dynamics whose transfer operator is assembled block-wise, one block per isotypic subspace.

    The state representation is decomposed with `isotypic_basis`. The resulting change of basis is handed to
    `LinearDynamics`, and `update_transfer_op` fits one operator per isotypic subspace before stacking them into a
    block-diagonal transfer operator.
    """

    def __init__(self,
                 state_rep: Representation = None,
                 dmd_algorithm: Optional[DmdSolver] = None,
                 dt: Optional[Union[float, int]] = 1,
                 trainable=False,
                 group_avg_trick: bool = True):
        """Decompose the state space into isotypic subspaces and initialize the parent linear dynamics.

        Args:
            state_rep: Group representation acting on the state space. Required, despite the `None` default.
            dmd_algorithm: Solver used to fit each isotypic block. Defaults to `full_rank_lstsq_symmetric`.
            dt: Time step forwarded to `LinearDynamics`.
            trainable: When True, `update_transfer_op` refuses to run.
            group_avg_trick: Whether the isotypic representations are passed to the DMD solver.

        Raises:
            ValueError: If `state_rep` is None.
        """
        if state_rep is None:
            # Fail early with a clear message instead of an AttributeError on `state_rep.group`.
            raise ValueError("`state_rep` is required to compute the isotypic decomposition of the state space")

        self.symm_group = state_rep.group
        self.group_avg_trick = group_avg_trick
        # Find the Isotypic basis of the state space
        self.state_iso_reps, self.state_iso_dims, Q_iso2state = isotypic_basis(representation=state_rep,
                                                                                multiplicity=1,
                                                                                prefix='ELDstate')
        # Change of coordinates required for state to be in Isotypic basis.
        # The inverse is computed before the conversion to a torch Tensor, so numpy never has to coerce a Tensor.
        Q_state2iso = Tensor(np.linalg.inv(Q_iso2state))
        Q_iso2state = Tensor(Q_iso2state)

        # One (initially empty) transfer operator per isotypic subspace, preserving the order of the subspaces.
        self.iso_transfer_op = OrderedDict((irrep_id, None) for irrep_id in self.state_iso_reps)

        self.is_trainable = trainable
        dmd_algorithm = dmd_algorithm if dmd_algorithm is not None else full_rank_lstsq_symmetric
        super(EquivLinearDynamics, self).__init__(state_rep=state_rep,
                                                  dt=dt,
                                                  dmd_algorithm=dmd_algorithm,
                                                  state_change_of_basis=Q_state2iso,
                                                  state_inv_change_of_basis=Q_iso2state)

    def update_transfer_op(self, X: Tensor, X_prime: Tensor, group_avg_trick: bool = True):
        """ Use a DMD algorithm to update the empirical transfer operator
        Args:
            X: (state_dim, n_samples) Data matrix of states at time `t`.
            X_prime: (state_dim, n_samples) Data matrix of the states at time `t + dt`.
            group_avg_trick: (bool) Whether to use the group average trick to enforce equivariance. The trick is
                applied only when both this flag and the `group_avg_trick` given at construction are True, so the
                default leaves the constructor setting in control.
        Returns:
            dict with the rank (`solution_op_rank`) and condition number (`solution_op_cond_num`) of the assembled
            transfer operator, the summed reconstruction error (`solution_op_error`) and the per-isotypic-subspace
            reconstruction errors (`solution_op_error_dist`).
        Raises:
            RuntimeError: If the model was initialized as trainable.
        """
        if self.is_trainable:
            raise RuntimeError("This model was initialized as trainable")
        assert X.shape == X_prime.shape, f"X: {X.shape}, X_prime: {X_prime.shape}"
        assert X.shape[0] == self.state_dim, f"Invalid state dimension {X.shape[0]} != {self.state_dim}"

        use_group_avg = self.group_avg_trick and group_avg_trick

        state, next_state = X.T, X_prime.T
        iso_rec_error = []
        # For each Isotypic Subspace, compute the empirical transfer operator.
        for irrep_id, iso_rep in self.state_iso_reps.items():
            # NOTE(review): this compares an irrep id against `symm_group.identity` to detect the trivial
            # subspace; confirm that the two are comparable in the escnn version in use.
            rep = iso_rep if irrep_id != self.symm_group.identity else None  # Check for Trivial Subspace
            solver_rep = rep if use_group_avg else None

            # Get the projection of the state onto the isotypic subspace
            iso_dims = self.state_iso_dims[irrep_id]
            X_iso = state[..., iso_dims].T  # (iso_state_dim, num_samples)
            X_iso_prime = next_state[..., iso_dims].T  # (iso_state_dim, num_samples)

            # Compute the empirical transfer operator of this Isotypic subspace
            A_iso = self.dmd_algorithm(X_iso, X_iso_prime, rep_X=solver_rep, rep_Y=solver_rep)
            rec_error = torch.nn.functional.mse_loss(A_iso @ X_iso, X_iso_prime)
            # Errors are only reported as metrics, so they are detached right away.
            iso_rec_error.append(rec_error.detach())
            self.iso_transfer_op[irrep_id] = A_iso

        transfer_op = torch.block_diag(*[self.iso_transfer_op[irrep_id] for irrep_id in self.state_iso_reps])
        assert transfer_op.shape == (self.state_dim, self.state_dim)
        self.transfer_op = transfer_op

        # Stack instead of rebuilding a Tensor from a Python list: this keeps the device and dtype of the errors.
        iso_rec_error = torch.stack(iso_rec_error)
        rec_error = torch.sum(iso_rec_error)

        return dict(solution_op_rank=torch.linalg.matrix_rank(transfer_op.detach()).to(torch.float),
                    solution_op_cond_num=torch.linalg.cond(transfer_op.detach()).to(torch.float),
                    solution_op_error=rec_error.to(torch.float),
                    solution_op_error_dist=iso_rec_error.to(torch.float))

nn/LightningModel.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from lightning import LightningModule
1414
from lightning.pytorch.utilities.types import STEP_OUTPUT
1515

16-
from nn.markov_dynamics import MarkovDynamicsModule
16+
from nn.markov_dynamics import MarkovDynamics
1717

1818
log = logging.getLogger(__name__)
1919

@@ -48,7 +48,7 @@ def __init__(self,
4848
self.save_hyperparameters()
4949
self._log_cache = {}
5050

51-
def set_model(self, model: MarkovDynamicsModule):
51+
def set_model(self, model: MarkovDynamics):
5252
self.model = model
5353
if hasattr(model, 'eval_metrics'):
5454
self.test_metrics_fn = model.eval_metrics
@@ -59,8 +59,7 @@ def forward(self, batch):
5959
return self.model(inputs)
6060

6161
def training_step(self, batch, batch_idx):
62-
n_steps = batch['next_state'].shape[1]
63-
outputs = self.model(**batch, n_steps=n_steps)
62+
outputs = self.model(**batch)
6463
loss, metrics = self.model.compute_loss_and_metrics(**outputs, **batch)
6564
vector_metrics, scalar_metrics = self.separate_vector_scalar_metrics(metrics)
6665

@@ -70,8 +69,7 @@ def training_step(self, batch, batch_idx):
7069
return loss
7170

7271
def validation_step(self, batch, batch_idx):
73-
n_steps = batch['next_state'].shape[1]
74-
outputs = self.model(**batch, n_steps=n_steps)
72+
outputs = self.model(**batch)
7573
loss, metrics = self.model.compute_loss_and_metrics(**outputs, **batch)
7674
vector_metrics, scalar_metrics = self.separate_vector_scalar_metrics(metrics)
7775

@@ -81,9 +79,7 @@ def validation_step(self, batch, batch_idx):
8179
return {'output': outputs, 'input': batch}
8280

8381
def test_step(self, batch, batch_idx):
84-
n_steps = batch['next_state'].shape[1]
85-
outputs = self.model(**batch, n_steps=n_steps)
86-
82+
outputs = self.model(**batch)
8783
loss, metrics = self.model.compute_loss_and_metrics(**outputs, **batch)
8884
vector_metrics, scalar_metrics = self.separate_vector_scalar_metrics(metrics)
8985

@@ -122,8 +118,8 @@ def on_fit_start(self) -> None:
122118
def on_train_start(self):
123119
# TODO: Add number of layers and hidden channels dimensions.
124120
hparams = flatten_dict(self._run_hps)
125-
if hasattr(self.model, "get_hparams"):
126-
hparams.update(flatten_dict(self.model.get_hparams()))
121+
# if hasattr(self.model, "get_hparams"):
122+
# hparams.update(flatten_dict(self.model.get_hparams()))
127123

128124
if self.val_metrics_fn is not None:
129125
self.compute_figure_metrics(self.val_metrics_fn, self.trainer.datamodule.train_dataloader(), suffix="train")
@@ -247,8 +243,7 @@ def log_figures(self, figs: dict[str, plotly.graph_objs.Figure], suffix=''):
247243
@torch.no_grad()
248244
def compute_figure_metrics(self, metrics_fn: Callable, dataloader, suffix=''):
249245
batch = next(iter(dataloader))
250-
n_steps = batch['next_state'].shape[1]
251-
outputs = self.model(**batch, n_steps=n_steps)
246+
outputs = self.model(**batch)
252247

253248
figs, metrics = metrics_fn(**outputs, **batch)
254249

0 commit comments

Comments
 (0)