Skip to content

Commit 99c10f2

Browse files
committed
Fix the OLS bug
The number of data samples used to compute the DMD least-squares problem was hardcoded to a low number, which understandably produced unstable and brittle performance.
1 parent 610ad1f commit 99c10f2

File tree

3 files changed

+34
-49
lines changed

3 files changed

+34
-49
lines changed

cfg/params/reg_exp.yaml

-6
This file was deleted.

nn/DeepProjections.py

+33-39
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from data.DynamicsDataModule import DynamicsDataModule
1818
from nn.LinearDynamics import LinearDynamics
19-
from nn.TwinMLP import TwinMLP
19+
from nn.ObservableNet import ObservableNet
2020
from nn.latent_markov_dynamics import LatentMarkovDynamics
2121
from nn.markov_dynamics import MarkovDynamics
2222
from nn.mlp import MLP
@@ -42,7 +42,6 @@ class DPNet(LatentMarkovDynamics):
4242
batch_norm=True,
4343
bias=False,
4444
init_mode='fan_in',
45-
backbone_layers=-2 # num_layers - 2
4645
)
4746

4847
def __init__(
@@ -73,11 +72,11 @@ def __init__(
7372
_obs_fn_params = self._default_obs_fn_params.copy()
7473
if obs_fn_params is not None:
7574
_obs_fn_params.update(obs_fn_params)
75+
# Define the linear dynamics of the observable state space
76+
obs_state_dym = self.build_obs_dyn_module()
7677
# Build the observation function and its inverse
7778
obs_fn = self.build_obs_fn(**_obs_fn_params)
7879
inv_obs_fn = self.build_inv_obs_fn(linear_decoder=linear_decoder, **_obs_fn_params)
79-
# Define the linear dynamics of the observable state space
80-
obs_state_dym = self.build_obs_dyn_module()
8180
# Variable holding the transfer operator used to evolve the observable state in time.
8281

8382
# Initialize the base class
@@ -127,10 +126,7 @@ def pre_process_obs_state(self,
127126
- obs_state_traj_aux: (batch, time, obs_state_dim) tensor.
128127
"""
129128
obs_state_traj = super().pre_process_obs_state(obs_state_traj)['obs_state_traj']
130-
if obs_state_traj_aux is not None:
131-
obs_state_traj_aux = super().pre_process_obs_state(obs_state_traj_aux)['obs_state_traj']
132-
else:
133-
obs_state_traj_aux = None
129+
obs_state_traj_aux = super().pre_process_obs_state(obs_state_traj_aux)['obs_state_traj']
134130
return dict(obs_state_traj=obs_state_traj, obs_state_traj_aux=obs_state_traj_aux)
135131

136132
def compute_loss_and_metrics(self,
@@ -241,22 +237,30 @@ def eval_metrics(self,
241237
pred_obs_state_trajs=pred_obs_state_traj,
242238
dt=self.dt,
243239
n_trajs_to_show=5)
240+
figs = dict(prediction=fig)
244241
if self.obs_state_dim == 3:
245-
fig_3ds = plot_system_3D(trajectories=state_traj, secondary_trajectories=pred_state_traj,
246-
title='state_traj', num_trajs_to_show=20)
247242
fig_3do = plot_system_3D(trajectories=obs_state_traj, secondary_trajectories=pred_obs_state_traj,
248243
title='obs_state', num_trajs_to_show=20)
249-
figs = dict(prediction=fig, state=fig_3ds, obs_state=fig_3do)
250-
elif self.obs_state_dim == 2:
251-
fig_2ds = plot_system_2D(trajs=state_traj, secondary_trajs=pred_state_traj, alpha=0.2,
252-
num_trajs_to_show=10)
244+
if obs_state_traj_aux is not None:
245+
fig_3do = plot_system_3D(trajectories=obs_state_traj_aux, legendgroup='aux', traj_colorscale='solar',
246+
num_trajs_to_show=20, fig=fig_3do)
247+
figs['obs_state'] = fig_3do
248+
if self.state_dim == 3:
249+
fig_3ds = plot_system_3D(trajectories=state_traj, secondary_trajectories=pred_state_traj,
250+
title='state_traj', num_trajs_to_show=20)
251+
figs['state'] = fig_3ds
252+
253+
if self.obs_state_dim == 2:
253254
fig_2do = plot_system_2D(trajs=obs_state_traj, secondary_trajs=pred_obs_state_traj, alpha=0.2,
254255
num_trajs_to_show=10)
255-
if self.aux_obs_space:
256-
plot_system_2D(trajs=obs_state_traj_aux, legendgroup='aux', num_trajs_to_show=10, fig=fig_2ds)
257-
figs = dict(prediction=fig, state=fig_2ds, obs_state=fig_2do)
258-
else:
259-
figs = dict(prediction=fig)
256+
if obs_state_traj_aux is not None:
257+
fig_2do = plot_system_2D(trajs=obs_state_traj_aux, legendgroup='aux',
258+
num_trajs_to_show=10, fig=fig_2do)
259+
figs['obs_state'] = fig_2do
260+
if self.state_dim == 2:
261+
fig_2ds = plot_system_2D(trajs=state_traj, secondary_trajs=pred_state_traj, alpha=0.2,
262+
num_trajs_to_show=10)
263+
figs['state'] = fig_2ds
260264

261265
metrics = None
262266
return figs, metrics
@@ -285,8 +289,6 @@ def approximate_transfer_operator(self, train_data_loader: DataLoader):
285289
train_data[key] = torch.squeeze(value)
286290
else:
287291
torch.cat([train_data[key], torch.squeeze(value)], dim=0)
288-
for key, value in train_data.items():
289-
train_data[key] = value[:6]
290292

291293
# Apply any pre-processing to the state and next state
292294
state, next_state = train_data["state"], train_data["next_state"]
@@ -309,6 +311,8 @@ def approximate_transfer_operator(self, train_data_loader: DataLoader):
309311
metrics['rank_obs_state'] = torch.linalg.matrix_rank(X).detach().to(torch.float)
310312

311313
if self.linear_decoder:
314+
# Predict the pre-processed state from the observable state
315+
# pre_state = self.pre_process_state(state)
312316
inv_projector, inv_projector_metrics = self.empirical_lin_inverse_projector(state, obs_state)
313317
metrics.update(inv_projector_metrics)
314318
self.inverse_projector = inv_projector
@@ -318,24 +322,14 @@ def approximate_transfer_operator(self, train_data_loader: DataLoader):
318322
def build_obs_fn(self, num_layers, identity=False, **kwargs):
319323
if identity:
320324
return lambda x: (x, x)
325+
obs_fn = MLP(in_dim=self.state_dim, out_dim=self.obs_state_dim, num_layers=num_layers,
326+
head_with_activation=False, **kwargs)
327+
obs_fn_aux = None
328+
if self.aux_obs_space:
329+
obs_fn_aux = MLP(in_dim=self.state_dim, out_dim=self.obs_state_dim, num_layers=num_layers,
330+
head_with_activation=False, **kwargs)
321331

322-
num_backbone_layers = kwargs.pop('backbone_layers', num_layers - 2 if self.aux_obs_space else 0)
323-
if num_backbone_layers < 0:
324-
num_backbone_layers = num_layers - num_backbone_layers
325-
backbone_params = None
326-
if num_backbone_layers > 0 and self.aux_obs_space:
327-
backbone_feat_dim = kwargs.get('num_hidden_units')
328-
backbone_params = dict(in_dim=self.state_dim, out_dim=backbone_feat_dim,
329-
num_layers=num_backbone_layers, head_with_activation=True, **copy.copy(kwargs))
330-
kwargs['bias'] = False
331-
kwargs['batch_norm'] = False
332-
obs_fn_params = dict(in_dim=backbone_feat_dim, out_dim=self.obs_state_dim,
333-
num_layers=num_layers - num_backbone_layers, head_with_activation=False, **kwargs)
334-
else:
335-
obs_fn_params = dict(in_dim=self.state_dim, out_dim=self.obs_state_dim, num_layers=num_layers,
336-
head_with_activation=False, **kwargs)
337-
338-
return TwinMLP(net_kwargs=obs_fn_params, backbone_kwargs=backbone_params, fake_aux_fn=not self.aux_obs_space)
332+
return ObservableNet(obs_fn=obs_fn, obs_fn_aux=obs_fn_aux)
339333

340334
def build_inv_obs_fn(self, num_layers, linear_decoder: bool, identity=False, **kwargs):
341335
if identity:
@@ -352,7 +346,7 @@ def decoder(dpnet: DPNet, obs_state: Tensor):
352346
return MLP(in_dim=self.obs_state_dim, out_dim=self.state_dim, num_layers=num_layers, **kwargs)
353347

354348
def build_obs_dyn_module(self) -> MarkovDynamics:
355-
return LinearDynamics(state_dim=self.obs_state_dim, dt=self.dt)
349+
return LinearDynamics(state_dim=self.obs_state_dim, dt=self.dt, trainable=False)
356350

357351
def empirical_lin_inverse_projector(self, state: Tensor, obs_state: Tensor):
358352
""" Compute the empirical inverse projector from the observable state to the pre-processed state.

nn/TwinMLP.py nn/ObservableNet.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ def forward(self, input):
4141
return output1, output2
4242

4343
def get_hparams(self):
44-
if self.fake_aux_fn:
45-
return dict(fn1=self.fn1.get_hparams())
46-
else:
47-
return dict(fn1=self.fn1.get_hparams(), fn2=self.fn2.get_hparams())
44+
return {}
4845

4946

0 commit comments

Comments
 (0)