16
16
17
17
from data .DynamicsDataModule import DynamicsDataModule
18
18
from nn .LinearDynamics import LinearDynamics
19
- from nn .TwinMLP import TwinMLP
19
+ from nn .ObservableNet import ObservableNet
20
20
from nn .latent_markov_dynamics import LatentMarkovDynamics
21
21
from nn .markov_dynamics import MarkovDynamics
22
22
from nn .mlp import MLP
@@ -42,7 +42,6 @@ class DPNet(LatentMarkovDynamics):
42
42
batch_norm = True ,
43
43
bias = False ,
44
44
init_mode = 'fan_in' ,
45
- backbone_layers = - 2 # num_layers - 2
46
45
)
47
46
48
47
def __init__ (
@@ -73,11 +72,11 @@ def __init__(
73
72
_obs_fn_params = self ._default_obs_fn_params .copy ()
74
73
if obs_fn_params is not None :
75
74
_obs_fn_params .update (obs_fn_params )
75
+ # Define the linear dynamics of the observable state space
76
+ obs_state_dym = self .build_obs_dyn_module ()
76
77
# Build the observation function and its inverse
77
78
obs_fn = self .build_obs_fn (** _obs_fn_params )
78
79
inv_obs_fn = self .build_inv_obs_fn (linear_decoder = linear_decoder , ** _obs_fn_params )
79
- # Define the linear dynamics of the observable state space
80
- obs_state_dym = self .build_obs_dyn_module ()
81
80
# Variable holding the transfer operator used to evolve the observable state in time.
82
81
83
82
# Initialize the base class
@@ -127,10 +126,7 @@ def pre_process_obs_state(self,
127
126
- obs_state_traj_aux: (batch, time, obs_state_dim) tensor.
128
127
"""
129
128
obs_state_traj = super ().pre_process_obs_state (obs_state_traj )['obs_state_traj' ]
130
- if obs_state_traj_aux is not None :
131
- obs_state_traj_aux = super ().pre_process_obs_state (obs_state_traj_aux )['obs_state_traj' ]
132
- else :
133
- obs_state_traj_aux = None
129
+ obs_state_traj_aux = super ().pre_process_obs_state (obs_state_traj_aux )['obs_state_traj' ]
134
130
return dict (obs_state_traj = obs_state_traj , obs_state_traj_aux = obs_state_traj_aux )
135
131
136
132
def compute_loss_and_metrics (self ,
@@ -241,22 +237,30 @@ def eval_metrics(self,
241
237
pred_obs_state_trajs = pred_obs_state_traj ,
242
238
dt = self .dt ,
243
239
n_trajs_to_show = 5 )
240
+ figs = dict (prediction = fig )
244
241
if self .obs_state_dim == 3 :
245
- fig_3ds = plot_system_3D (trajectories = state_traj , secondary_trajectories = pred_state_traj ,
246
- title = 'state_traj' , num_trajs_to_show = 20 )
247
242
fig_3do = plot_system_3D (trajectories = obs_state_traj , secondary_trajectories = pred_obs_state_traj ,
248
243
title = 'obs_state' , num_trajs_to_show = 20 )
249
- figs = dict (prediction = fig , state = fig_3ds , obs_state = fig_3do )
250
- elif self .obs_state_dim == 2 :
251
- fig_2ds = plot_system_2D (trajs = state_traj , secondary_trajs = pred_state_traj , alpha = 0.2 ,
252
- num_trajs_to_show = 10 )
244
+ if obs_state_traj_aux is not None :
245
+ fig_3do = plot_system_3D (trajectories = obs_state_traj_aux , legendgroup = 'aux' , traj_colorscale = 'solar' ,
246
+ num_trajs_to_show = 20 , fig = fig_3do )
247
+ figs ['obs_state' ] = fig_3do
248
+ if self .state_dim == 3 :
249
+ fig_3ds = plot_system_3D (trajectories = state_traj , secondary_trajectories = pred_state_traj ,
250
+ title = 'state_traj' , num_trajs_to_show = 20 )
251
+ figs ['state' ] = fig_3ds
252
+
253
+ if self .obs_state_dim == 2 :
253
254
fig_2do = plot_system_2D (trajs = obs_state_traj , secondary_trajs = pred_obs_state_traj , alpha = 0.2 ,
254
255
num_trajs_to_show = 10 )
255
- if self .aux_obs_space :
256
- plot_system_2D (trajs = obs_state_traj_aux , legendgroup = 'aux' , num_trajs_to_show = 10 , fig = fig_2ds )
257
- figs = dict (prediction = fig , state = fig_2ds , obs_state = fig_2do )
258
- else :
259
- figs = dict (prediction = fig )
256
+ if obs_state_traj_aux is not None :
257
+ fig_2do = plot_system_2D (trajs = obs_state_traj_aux , legendgroup = 'aux' ,
258
+ num_trajs_to_show = 10 , fig = fig_2do )
259
+ figs ['obs_state' ] = fig_2do
260
+ if self .state_dim == 2 :
261
+ fig_2ds = plot_system_2D (trajs = state_traj , secondary_trajs = pred_state_traj , alpha = 0.2 ,
262
+ num_trajs_to_show = 10 )
263
+ figs ['state' ] = fig_2ds
260
264
261
265
metrics = None
262
266
return figs , metrics
@@ -285,8 +289,6 @@ def approximate_transfer_operator(self, train_data_loader: DataLoader):
285
289
train_data [key ] = torch .squeeze (value )
286
290
else :
287
291
torch .cat ([train_data [key ], torch .squeeze (value )], dim = 0 )
288
- for key , value in train_data .items ():
289
- train_data [key ] = value [:6 ]
290
292
291
293
# Apply any pre-processing to the state and next state
292
294
state , next_state = train_data ["state" ], train_data ["next_state" ]
@@ -309,6 +311,8 @@ def approximate_transfer_operator(self, train_data_loader: DataLoader):
309
311
metrics ['rank_obs_state' ] = torch .linalg .matrix_rank (X ).detach ().to (torch .float )
310
312
311
313
if self .linear_decoder :
314
+ # Predict the pre-processed state from the observable state
315
+ # pre_state = self.pre_process_state(state)
312
316
inv_projector , inv_projector_metrics = self .empirical_lin_inverse_projector (state , obs_state )
313
317
metrics .update (inv_projector_metrics )
314
318
self .inverse_projector = inv_projector
@@ -318,24 +322,14 @@ def approximate_transfer_operator(self, train_data_loader: DataLoader):
318
322
def build_obs_fn (self , num_layers , identity = False , ** kwargs ):
319
323
if identity :
320
324
return lambda x : (x , x )
325
+ obs_fn = MLP (in_dim = self .state_dim , out_dim = self .obs_state_dim , num_layers = num_layers ,
326
+ head_with_activation = False , ** kwargs )
327
+ obs_fn_aux = None
328
+ if self .aux_obs_space :
329
+ obs_fn_aux = MLP (in_dim = self .state_dim , out_dim = self .obs_state_dim , num_layers = num_layers ,
330
+ head_with_activation = False , ** kwargs )
321
331
322
- num_backbone_layers = kwargs .pop ('backbone_layers' , num_layers - 2 if self .aux_obs_space else 0 )
323
- if num_backbone_layers < 0 :
324
- num_backbone_layers = num_layers - num_backbone_layers
325
- backbone_params = None
326
- if num_backbone_layers > 0 and self .aux_obs_space :
327
- backbone_feat_dim = kwargs .get ('num_hidden_units' )
328
- backbone_params = dict (in_dim = self .state_dim , out_dim = backbone_feat_dim ,
329
- num_layers = num_backbone_layers , head_with_activation = True , ** copy .copy (kwargs ))
330
- kwargs ['bias' ] = False
331
- kwargs ['batch_norm' ] = False
332
- obs_fn_params = dict (in_dim = backbone_feat_dim , out_dim = self .obs_state_dim ,
333
- num_layers = num_layers - num_backbone_layers , head_with_activation = False , ** kwargs )
334
- else :
335
- obs_fn_params = dict (in_dim = self .state_dim , out_dim = self .obs_state_dim , num_layers = num_layers ,
336
- head_with_activation = False , ** kwargs )
337
-
338
- return TwinMLP (net_kwargs = obs_fn_params , backbone_kwargs = backbone_params , fake_aux_fn = not self .aux_obs_space )
332
+ return ObservableNet (obs_fn = obs_fn , obs_fn_aux = obs_fn_aux )
339
333
340
334
def build_inv_obs_fn (self , num_layers , linear_decoder : bool , identity = False , ** kwargs ):
341
335
if identity :
@@ -352,7 +346,7 @@ def decoder(dpnet: DPNet, obs_state: Tensor):
352
346
return MLP (in_dim = self .obs_state_dim , out_dim = self .state_dim , num_layers = num_layers , ** kwargs )
353
347
354
348
def build_obs_dyn_module (self ) -> MarkovDynamics :
355
- return LinearDynamics (state_dim = self .obs_state_dim , dt = self .dt )
349
+ return LinearDynamics (state_dim = self .obs_state_dim , dt = self .dt , trainable = False )
356
350
357
351
def empirical_lin_inverse_projector (self , state : Tensor , obs_state : Tensor ):
358
352
""" Compute the empirical inverse projector from the observable state to the pre-processed state.
0 commit comments