 from data.DynamicsDataModule import DynamicsDataModule
 from nn.DeepProjections import DPNet
 from nn.EquivLinearDynamics import EquivLinearDynamics
-from nn.TwinMLP import TwinMLP
+from nn.ObservableNet import ObservableNet
 from nn.emlp import EMLP
 from nn.markov_dynamics import MarkovDynamics
 from utils.losses_and_metrics import forecasting_loss_and_metrics, obs_state_space_metrics
@@ -38,7 +38,7 @@ class EquivDPNet(DPNet):
         activation="p_elu",
         batch_norm=True,
         bias=False,
-        backbone_layers=-2  # num_layers - 2
+        # backbone_layers=-2  # num_layers - 2
     )
 
     def __init__(self,
@@ -257,40 +257,27 @@ def empirical_lin_inverse_projector(self, state: Tensor, obs_state: Tensor):
         return A, metrics
 
     def build_obs_fn(self, num_layers, **kwargs):
-        num_backbone_layers = kwargs.pop('backbone_layers', num_layers - 2 if self.aux_obs_space else 0)
-        if num_backbone_layers < 0:
-            num_backbone_layers = num_layers - num_backbone_layers
-        backbone_params = None
-        if num_backbone_layers > 0 and self.aux_obs_space:
-            num_hidden_units = kwargs.get('num_hidden_units')
-            activation_type = kwargs.pop('activation')
-            num_hidden_regular_fields = int(np.ceil(num_hidden_units // self.state_type_iso.size))
-            act = EMLP.get_activation(activation=activation_type,
-                                      in_type=self.state_type_iso,
-                                      channels=num_hidden_regular_fields)
-            backbone_params = dict(in_type=self.state_type_iso,
-                                   out_type=act.out_type,
-                                   activation=act,
-                                   num_layers=num_backbone_layers,
-                                   head_with_activation=True,
-                                   **copy.copy(kwargs))
-            kwargs['bias'] = False
-            kwargs['batch_norm'] = False
-            obs_fn_params = dict(in_type=act.out_type, out_type=self.obs_state_type,
-                                 num_layers=num_layers - num_backbone_layers,
-                                 activation=act,
-                                 head_with_activation=False, **kwargs)
-        else:
-            obs_fn_params = dict(in_type=self.state_type_iso,
-                                 out_type=self.obs_state_type,
-                                 num_layers=num_layers,
-                                 head_with_activation=False,
-                                 **kwargs)
 
-        return TwinMLP(net_kwargs=obs_fn_params,
-                       backbone_kwargs=backbone_params,
-                       fake_aux_fn=not self.aux_obs_space,
-                       equivariant=True)
+        num_hidden_units = kwargs.get('num_hidden_units')
+        activation_type = kwargs.pop('activation')
+        act = EMLP.get_activation(activation=activation_type,
+                                  in_type=self.state_type_iso,
+                                  desired_hidden_units=num_hidden_units)
+
+        obs_fn = EMLP(in_type=self.state_type_iso,
+                      out_type=self.obs_state_type,
+                      num_layers=num_layers,
+                      activation=act,
+                      **kwargs)
+        obs_fn_aux = None
+        if self.aux_obs_space:
+            obs_fn_aux = EMLP(in_type=self.state_type_iso,
+                              out_type=self.obs_state_type,
+                              num_layers=num_layers,
+                              activation=act,
+                              **kwargs)
+
+        return ObservableNet(obs_fn=obs_fn, obs_fn_aux=obs_fn_aux)
 
     def build_inv_obs_fn(self, num_layers, linear_decoder: bool, **kwargs):
         if linear_decoder:
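
Note: this hunk only shows the new call site, ObservableNet(obs_fn=obs_fn, obs_fn_aux=obs_fn_aux); nn/ObservableNet.py itself is outside the diff. A minimal sketch of a wrapper consistent with that call, assuming it simply holds the main observable network and an optional auxiliary one (the class body and forward signature below are assumptions, not the repository's actual implementation):

import torch

class ObservableNet(torch.nn.Module):
    # Hypothetical sketch: bundles the main observable function with an optional auxiliary one.
    def __init__(self, obs_fn, obs_fn_aux=None):
        super().__init__()
        self.obs_fn = obs_fn
        self.obs_fn_aux = obs_fn_aux

    def forward(self, state):
        # Evaluate the main observable embedding; reuse it for the auxiliary
        # observable space when no dedicated auxiliary network was provided.
        obs = self.obs_fn(state)
        obs_aux = self.obs_fn_aux(state) if self.obs_fn_aux is not None else obs
        return obs, obs_aux

Compared with the removed TwinMLP path, the two EMLPs built above no longer share a backbone: each observable space now gets its own full network.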
@@ -307,7 +294,7 @@ def decoder(dpnet: DPNet, obs_state: Tensor):
                     **kwargs)
 
     def build_obs_dyn_module(self) -> MarkovDynamics:
-        return EquivLinearDynamics(state_rep=self.obs_state_type.representation,
+        return EquivLinearDynamics(state_type=self.obs_state_type,
                                    dt=self.dt,
                                    trainable=False,
                                    group_avg_trick=self.group_avg_trick)
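
For context on this last hunk: self.obs_state_type is an escnn FieldType, which bundles the group representation with the gspace it acts on, so the constructor can recover everything the old state_rep argument supplied. A minimal hedged sketch of that recovery (the helper name is ours, not the repository's):

from escnn.nn import FieldType

def state_rep_from_field_type(state_type: FieldType):
    # Hypothetical helper: return the direct-sum representation the old `state_rep`
    # argument passed explicitly, plus the symmetry group (useful for group averaging).
    return state_type.representation, state_type.fibergroup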