
Commit b8bb708 (1 parent: 99c10f2)

    Equiv-DPnets working on high dimensional spaces.

17 files changed: +104 -116

cfg/model/dae.yaml (+2 -2)

@@ -3,13 +3,13 @@ defaults:

 name: DAE
 # Model hyperparameters
-obs_pred_w: 1.0  # Cost function weight for prediction in observation space Z
+obs_pred_w: 5.0  # Cost function weight for prediction in observation space Z
 orth_w: 1.0  # Weight of the orthonormal regularization term in the loss function
 corr_w: 0.0

 # Optimization hyperparameters
 lr: 1e-3
-batch_size: 1024
+activation: ELU
 equivariant: False

 summary: ${model.name}-Obs_w:${model.obs_pred_w}-Orth_w:${model.orth_w}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}
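
The summary string resolves ${model.activation} at run time, so the new key must be spelled exactly `activation`. A minimal sanity check with OmegaConf, which Hydra uses for interpolation (the truncated config below is illustrative, not the full dae.yaml):

    from omegaconf import OmegaConf

    # Stand-in for cfg/model/dae.yaml with only the keys the summary string needs.
    cfg = OmegaConf.create({
        "model": {
            "name": "DAE", "obs_pred_w": 5.0, "orth_w": 1.0, "activation": "ELU",
            "summary": "${model.name}-Obs_w:${model.obs_pred_w}-Orth_w:${model.orth_w}-Act:${model.activation}",
        }
    })
    print(cfg.model.summary)  # DAE-Obs_w:5.0-Orth_w:1.0-Act:ELU
    # A misspelled key (e.g. 'actiavtion') would raise InterpolationKeyError on access.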

cfg/model/dpnet.yaml (+1 -2)

@@ -7,14 +7,13 @@ name: DPNet
 equivariant: False

 # Model hyperparameters
-activation: ReLU
+activation: ELU
 num_layers: 5  # Number of MLP layers (including input and output layers)
 num_hidden_units: 128  # Number of hidden units in each layer
 batch_norm: True
 bias: False
 # Optimization hyperparameters
 lr: 1e-3
-batch_size: 1024

 max_ck_window_length: ${system.pred_horizon}  # Maximum length of the Chapman-Kolmogorov window
 ck_w: 0.0  # Weight of the Chapman-Kolmogorov regularization term in the loss function

cfg/model/edae.yaml (+2)

@@ -9,5 +9,7 @@ corr_w: 0.0
 state_dependent_obs_dyn: False  # Whether to use state-dependent observation dynamics
 group_avg_trick: True

+
+activation: ELU
 equivariant: True


nn/DeepProjections.py (+1)

@@ -129,6 +129,7 @@ def pre_process_obs_state(self,
         obs_state_traj_aux = super().pre_process_obs_state(obs_state_traj_aux)['obs_state_traj']
         return dict(obs_state_traj=obs_state_traj, obs_state_traj_aux=obs_state_traj_aux)

+
     def compute_loss_and_metrics(self,
                                  obs_state_traj: Tensor,
                                  obs_state_traj_aux: Tensor,

nn/DynamicsAutoEncoder.py (+3 -3)

@@ -87,11 +87,11 @@ def forecast(self, state: Tensor, n_steps: int = 1, **kwargs) -> [dict[str, Tensor]]:
             f"{pred_obs_state_traj.shape}!=({self._batch_size}, {time_horizon}, {self.obs_state_dim})"
         return pred_state_traj, pred_obs_state_traj

-    def post_process_obs_state(self, pred_state_traj: Tensor, pred_state_one_step: Tensor) -> dict[str, Tensor]:
+    def post_process_obs_state(self, obs_state_traj: Tensor, pred_state_one_step: Tensor) -> dict[str, Tensor]:
         """ Post-process the predicted observable state trajectory given by the observable state dynamics.

         Args:
-            pred_state_traj: (batch, time, obs_state_dim) Trajectory of the predicted (time -1) observable states
+            obs_state_traj: (batch, time, obs_state_dim) Trajectory of the predicted (time -1) observable states
                 predicted by the transfer operator.
             pred_state_one_step: (batch, time, obs_state_dim) Trajectory of the predicted one-step ahead (time)
                 observable states predicted by the transfer operator.
@@ -101,7 +101,7 @@ def post_process_obs_state(self, pred_state_traj: Tensor, pred_state_one_step: Tensor) -> dict[str, Tensor]:
             - pred_obs_state_traj: (batch * time, obs_state_dim) Trajectory
             - pred_obs_state_one_step: (batch, time, obs_state_dim) Trajectory
         """
-        batched_pred_obs_state_traj = batched_to_flat_trajectory(pred_state_traj)
+        batched_pred_obs_state_traj = batched_to_flat_trajectory(obs_state_traj)
         return dict(pred_obs_state_traj=batched_pred_obs_state_traj,
                     pred_obs_state_one_step=pred_state_one_step)

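Both post_process_obs_state variants touched by this commit funnel through batched_to_flat_trajectory, imported from the project's utils. A minimal sketch of the behavior the docstrings imply, assuming the helper is a plain reshape (the real implementation may differ):

    import torch
    from torch import Tensor

    def batched_to_flat_trajectory(traj: Tensor) -> Tensor:
        # (batch, time, dim) -> (batch * time, dim), preserving time order per batch element.
        batch, time, dim = traj.shape
        return traj.reshape(batch * time, dim)

    flat = batched_to_flat_trajectory(torch.randn(32, 10, 8))
    assert flat.shape == (320, 8)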

nn/EquivDeepPojections.py (+23 -36)

@@ -21,7 +21,7 @@
 from data.DynamicsDataModule import DynamicsDataModule
 from nn.DeepProjections import DPNet
 from nn.EquivLinearDynamics import EquivLinearDynamics
-from nn.TwinMLP import TwinMLP
+from nn.ObservableNet import ObservableNet
 from nn.emlp import EMLP
 from nn.markov_dynamics import MarkovDynamics
 from utils.losses_and_metrics import forecasting_loss_and_metrics, obs_state_space_metrics
@@ -38,7 +38,7 @@ class EquivDPNet(DPNet):
         activation="p_elu",
         batch_norm=True,
         bias=False,
-        backbone_layers=-2  # num_layers - 2
+        # backbone_layers=-2  # num_layers - 2
         )

     def __init__(self,
@@ -257,40 +257,27 @@ def empirical_lin_inverse_projector(self, state: Tensor, obs_state: Tensor):
         return A, metrics

     def build_obs_fn(self, num_layers, **kwargs):
-        num_backbone_layers = kwargs.pop('backbone_layers', num_layers - 2 if self.aux_obs_space else 0)
-        if num_backbone_layers < 0:
-            num_backbone_layers = num_layers - num_backbone_layers
-        backbone_params = None
-        if num_backbone_layers > 0 and self.aux_obs_space:
-            num_hidden_units = kwargs.get('num_hidden_units')
-            activation_type = kwargs.pop('activation')
-            num_hidden_regular_fields = int(np.ceil(num_hidden_units // self.state_type_iso.size))
-            act = EMLP.get_activation(activation=activation_type,
-                                      in_type=self.state_type_iso,
-                                      channels=num_hidden_regular_fields)
-            backbone_params = dict(in_type=self.state_type_iso,
-                                   out_type=act.out_type,
-                                   activation=act,
-                                   num_layers=num_backbone_layers,
-                                   head_with_activation=True,
-                                   **copy.copy(kwargs))
-            kwargs['bias'] = False
-            kwargs['batch_norm'] = False
-            obs_fn_params = dict(in_type=act.out_type, out_type=self.obs_state_type,
-                                 num_layers=num_layers - num_backbone_layers,
-                                 activation=act,
-                                 head_with_activation=False, **kwargs)
-        else:
-            obs_fn_params = dict(in_type=self.state_type_iso,
-                                 out_type=self.obs_state_type,
-                                 num_layers=num_layers,
-                                 head_with_activation=False,
-                                 **kwargs)

-        return TwinMLP(net_kwargs=obs_fn_params,
-                       backbone_kwargs=backbone_params,
-                       fake_aux_fn=not self.aux_obs_space,
-                       equivariant=True)
+        num_hidden_units = kwargs.get('num_hidden_units')
+        activation_type = kwargs.pop('activation')
+        act = EMLP.get_activation(activation=activation_type,
+                                  in_type=self.state_type_iso,
+                                  desired_hidden_units=num_hidden_units)
+
+        obs_fn = EMLP(in_type=self.state_type_iso,
+                      out_type=self.obs_state_type,
+                      num_layers=num_layers,
+                      activation=act,
+                      **kwargs)
+        obs_fn_aux = None
+        if self.aux_obs_space:
+            obs_fn_aux = EMLP(in_type=self.state_type_iso,
+                              out_type=self.obs_state_type,
+                              num_layers=num_layers,
+                              activation=act,
+                              **kwargs)
+
+        return ObservableNet(obs_fn=obs_fn, obs_fn_aux=obs_fn_aux)

     def build_inv_obs_fn(self, num_layers, linear_decoder: bool, **kwargs):
         if linear_decoder:
@@ -307,7 +294,7 @@ def decoder(dpnet: DPNet, obs_state: Tensor):
                                 **kwargs)

     def build_obs_dyn_module(self) -> MarkovDynamics:
-        return EquivLinearDynamics(state_rep=self.obs_state_type.representation,
+        return EquivLinearDynamics(state_type=self.obs_state_type,
                                    dt=self.dt,
                                    trainable=False,
                                    group_avg_trick=self.group_avg_trick)
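
The new build_obs_fn drops the shared-backbone TwinMLP in favor of two independent EMLPs that share a single activation module, paired inside an ObservableNet. A hypothetical, self-contained construction for a small C4-symmetric system (all types and sizes below are made up for illustration):

    import escnn
    from escnn.nn import FieldType

    from nn.emlp import EMLP
    from nn.ObservableNet import ObservableNet

    gspace = escnn.gspaces.no_base_space(escnn.group.cyclic_group(4))
    G = gspace.fibergroup
    state_type = FieldType(gspace, [G.regular_representation] * 2)  # 8-dim state
    obs_type = FieldType(gspace, [G.regular_representation] * 4)    # 16-dim observable space

    act = EMLP.get_activation(activation="ELU", in_type=state_type, desired_hidden_units=128)
    obs_fn = EMLP(in_type=state_type, out_type=obs_type, num_layers=4, activation=act)
    obs_fn_aux = EMLP(in_type=state_type, out_type=obs_type, num_layers=4, activation=act)
    net = ObservableNet(obs_fn=obs_fn, obs_fn_aux=obs_fn_aux)

Passing the same act instance to both networks is harmless here since FourierPointwise is parameter-free; each EMLP presumably uses it only to infer the hidden field type.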

nn/EquivDynamicsAutoencoder.py (+5 -5)

@@ -107,19 +107,19 @@ def pre_process_obs_state(self,
                               ) -> dict[str, Tensor]:
         return super().pre_process_obs_state(obs_state_traj.tensor)

-    def post_process_obs_state(self, pred_state_traj: Tensor, **kwargs) -> dict[str, GeometricTensor]:
+    def post_process_obs_state(self, obs_state_traj: Tensor, **kwargs) -> dict[str, GeometricTensor]:
         """ Post-process the predicted observable state trajectory given by the observable state dynamics.

         Args:
-            pred_state_traj: (batch, time, obs_state_dim) Trajectory of the predicted (time -1) observable states
+            obs_state_traj: (batch, time, obs_state_dim) Trajectory of the predicted (time -1) observable states
                 predicted by the transfer operator.
             **kwargs:
         Returns:
             Dictionary containing
             - pred_obs_state_traj: (batch * time, obs_state_dim) Geometric Tensor Trajectory
         """
-        flat_pred_obs_state_traj = batched_to_flat_trajectory(pred_state_traj)
-        return dict(pred_obs_state_traj=self.obs_state_type(flat_pred_obs_state_traj))
+        flat_obs_state_traj = batched_to_flat_trajectory(obs_state_traj)
+        return dict(obs_state_traj=self.obs_state_type(flat_obs_state_traj))

     def post_process_state(self, state_traj: GeometricTensor) -> Tensor:
         state_traj_input_basis = super().post_process_state(state_traj=state_traj.tensor)
@@ -138,7 +138,7 @@ def build_inv_obs_fn(self, num_layers: int, **kwargs):
                                 **kwargs)

     def build_obs_dyn_module(self) -> MarkovDynamics:
-        return EquivLinearDynamics(state_type=self.state_type_iso,
+        return EquivLinearDynamics(state_type=self.obs_state_type,
                                    dt=self.dt,
                                    trainable=True,
                                    group_avg_trick=self.group_avg_trick)

nn/EquivLinearDynamics.py (+4 -6)

@@ -46,7 +46,8 @@ def __init__(self,
         Q_iso2state = Tensor(Q_iso2state)
         Q_state2iso = Tensor(np.linalg.inv(Q_iso2state))

-        super(EquivLinearDynamics, self).__init__(state_rep=state_type.representation,
+        super(EquivLinearDynamics, self).__init__(state_dim=state_type.size,
+                                                  state_rep=state_type.representation,
                                                   dt=dt,
                                                   trainable=trainable,
                                                   dmd_algorithm=dmd_algorithm,
@@ -77,13 +78,10 @@ def forcast(self, state: GeometricTensor, n_steps: int = 1, **kwargs) -> Tensor:
                 next_obs_state = self.transfer_op(current_state)
             else:
                 transfer_op = self.get_transfer_op()
-                next_obs_state = torch.nn.functional.linear(current_state, transfer_op)
+                next_obs_state = self.state_type((transfer_op @ current_state.tensor.T).T)
             pred_state_traj.append(next_obs_state)

-        if self.is_trainable:
-            pred_state_traj = torch.stack([gt.tensor for gt in pred_state_traj], dim=1)
-        else:
-            pred_state_traj = torch.stack(pred_state_traj, dim=1)
+        pred_state_traj = torch.stack([gt.tensor for gt in pred_state_traj], dim=1)
         assert pred_state_traj.shape == (batch, n_steps + 1, state_dim)
         return pred_state_traj

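In the non-trainable branch, torch.nn.functional.linear cannot consume a GeometricTensor; the replacement applies the operator to the underlying tensor and re-wraps the result in state_type, so both branches now yield GeometricTensors and the stack can unwrap them uniformly. On plain tensors the two forms agree, since (A X^T)^T = X A^T:

    import torch

    A = torch.randn(4, 4)  # stand-in transfer operator
    X = torch.randn(7, 4)  # batch of 7 row-vector states
    assert torch.allclose((A @ X.T).T, torch.nn.functional.linear(X, A), atol=1e-6)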

nn/LightningLatentMarkovDynamics.py (-4)

@@ -116,10 +116,6 @@ def on_fit_start(self) -> None:
         self._loss_metrics_fn = loss_metrics_fn

     def on_train_start(self):
-        # TODO: Add number of layers and hidden channels dimensions.
-        hparams = flatten_dict(self._run_hps)
-        if hasattr(self.model, "get_hparams"):
-            hparams.update(flatten_dict(self.model.get_hparams()))

         if hasattr(self.model, "approximate_transfer_operator"):
             metrics = self.model.approximate_transfer_operator(self.trainer.datamodule.predict_dataloader())

nn/LinearDynamics.py (-1)

@@ -91,7 +91,6 @@ def forcast(self, state: Tensor, n_steps: int = 1, **kwargs) -> Tensor:
             pred_state_traj.append(next_obs_state)

         pred_state_traj = torch.stack(pred_state_traj, dim=1)
-        # a = pred_state_traj.detach().cpu().numpy()
         assert pred_state_traj.shape == (batch, n_steps + 1, state_dim)
         return pred_state_traj


nn/ObservableNet.py (+28 -28)

@@ -1,46 +1,46 @@
-from typing import Optional
+import copy
+from typing import Optional, Union

+import escnn.nn
 import torch.nn
+from escnn.nn import EquivariantModule

+from nn.EquivLinearDynamics import EquivLinearDynamics
+from nn.LinearDynamics import LinearDynamics
 from nn.mlp import MLP
 from nn.emlp import EMLP


-class TwinMLP(torch.nn.Module):
-    """Auxiliary class to construct Twin MLPs with a potentially shared backbone."""
+class ObservableNet(torch.nn.Module):

-    def __init__(self, net_kwargs: dict, backbone_kwargs: Optional[dict] = None, equivariant=False, fake_aux_fn=False):
+    def __init__(self,
+                 obs_fn: Union[torch.nn.Module, EquivariantModule],
+                 obs_fn_aux: Optional[Union[torch.nn.Module, EquivariantModule]] = None):
         super().__init__()
-        self.fake_aux_fn = fake_aux_fn
-        self.shared_backbone = backbone_kwargs is not None
-        mlp_class = MLP if not equivariant else EMLP  # SO2MLP
+        self.equivariant = isinstance(obs_fn, EquivariantModule)
+        self.use_aux_obs_fn = obs_fn_aux is not None

-        if self.shared_backbone:
-            self.backbone = mlp_class(**backbone_kwargs)
-
-        self.fn1 = mlp_class(**net_kwargs)
-        if not fake_aux_fn:
-            self.fn2 = mlp_class(**net_kwargs)
+        self.obs = obs_fn
+        self.obs_aux = None
+        if self.use_aux_obs_fn:  # Use two twin networks to compute the main and auxiliary observable space.
+            self.obs_aux = obs_fn_aux
         else:
-            pass
+            if self.equivariant:
+                self.transfer_op_H_H_prime = escnn.nn.Linear(
+                    in_type=self.obs.out_type, out_type=self.obs.out_type, bias=False)
+            else:
+                self.transfer_op_H_H_prime = torch.nn.Linear(
+                    in_features=self.obs.out_dim, out_features=self.obs.out_dim, bias=False)

     def forward(self, input):

-        if self.shared_backbone:
-            backbone_output = self.backbone(input)
-            output1 = self.fn1(backbone_output)
-            output2 = self.fn2(backbone_output)
-        else:
-            if self.fake_aux_fn:
-                output1 = self.fn1(input)
-                output2 = output1
-            else:
-                output1 = self.fn1(input)
-                output2 = self.fn2(input)
+        obs_state = self.obs(input)

-        return output1, output2
+        if self.use_aux_obs_fn:
+            obs_aux_state = self.obs_aux(input)
+        else:
+            obs_aux_state = self.transfer_op_H_H_prime(obs_state)

-    def get_hparams(self):
-        return {}
+        return obs_state, obs_aux_state

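After the rewrite, ObservableNet either evaluates twin observable functions or, when no auxiliary network is given, derives the auxiliary observables through the learned linear map transfer_op_H_H_prime. A minimal non-equivariant usage sketch with a torch.nn.Linear stand-in for the project's MLP (the out_dim attribute is an assumption here, since the class reads it from non-equivariant modules):

    import torch
    from nn.ObservableNet import ObservableNet

    obs_fn = torch.nn.Linear(6, 4)
    obs_fn.out_dim = 4  # assumed attribute; the project's MLP presumably exposes this
    net = ObservableNet(obs_fn=obs_fn)  # no obs_fn_aux -> aux states come from transfer_op_H_H_prime

    z, z_aux = net(torch.randn(32, 6))
    assert z.shape == z_aux.shape == (32, 4)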

nn/emlp.py (+8 -12)

@@ -58,14 +58,11 @@ def __init__(self,

         if isinstance(activation, str):
             # Approximate the num of neurons as the num of signals in the space spanned by the irreps of the input type
-            self.num_hidden_regular_fields = int(np.ceil(num_hidden_units // self.in_type.size))
             # To compute the signal over the group we use all elements for finite groups
-            activation = self.get_activation(activation, in_type=in_type, channels=self.num_hidden_regular_fields)
+            activation = self.get_activation(activation, in_type=in_type, desired_hidden_units=num_hidden_units)
             hidden_type = activation.in_type
-            self.activation = activation
         elif isinstance(activation, EquivariantModule):
             hidden_type = activation.in_type
-            self.activation = activation
         else:
             raise ValueError(f"Activation type {type(activation)} not supported.")

@@ -104,17 +101,22 @@
         # self.net.check_equivariance()

     @staticmethod
-    def get_activation(activation, in_type: FieldType, channels: int):
+    def get_activation(activation, in_type: FieldType, desired_hidden_units: int):
         gspace = in_type.gspace
         group = gspace.fibergroup
         grid_length = group.order() if not group.continuous else 20
+
+        unique_irreps = set(in_type.irreps)
+        unique_irreps_dim = sum([group.irrep(*id).size for id in set(in_type.irreps)])
+        scale = in_type.size // unique_irreps_dim
+        channels = int(np.ceil(desired_hidden_units // unique_irreps_dim // scale))
         if "identity" in activation.lower():
             raise NotImplementedError("Identity activation not implemented yet")
             # return escnn.nn.IdentityModule()
         else:
             return escnn.nn.FourierPointwise(gspace,
                                              channels=channels,
-                                             irreps=in_type.irreps,
+                                             irreps=list(unique_irreps),
                                              function=f"p_{activation.lower()}",
                                              inplace=True,
                                              type='regular' if not group.continuous else 'rand',
@@ -124,12 +126,6 @@ def forward(self, x):
         """Forward pass of the EMLP model."""
         return self.net(x)

-    def get_hparams(self):
-        return {'num_layers': self.num_layers,
-                'hidden_ch': self.num_hidden_regular_fields,
-                'activation': str(self.activation.__class__.__name__),
-                }
-
     def reset_parameters(self, init_mode=None):
         """Initialize weights and biases of E-MLP model."""
         raise NotImplementedError()
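
The new channel computation sizes the Fourier activation from the dimension of the distinct irreps rather than the full input size, which is what makes high-dimensional input types workable. A worked instance of the arithmetic, assuming an input type built from 3 copies of a 4-dimensional set of unique irreps and a target width of 128 hidden units:

    import numpy as np

    in_type_size = 12          # e.g. 3 copies of the unique irreps
    unique_irreps_dim = 4      # total dimension of the distinct irreps
    desired_hidden_units = 128

    scale = in_type_size // unique_irreps_dim                                     # 3
    channels = int(np.ceil(desired_hidden_units // unique_irreps_dim // scale))   # 128 // 4 // 3
    assert channels == 10  # integer // division means np.ceil never rounds up here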
