Skip to content

Commit f282b20

Browse files
committed
Tests on MiniCheetah dataset
- This improves dynamics recordings efficiency when the dataset contains multiple observations that might not be used in state and action. - Fixed memory exponential growth due to innecesary copying on the IterableDataset generator. Now we only return views of the true data, until the batched sample is going to be generated. - Added plotting capabilities for data module, to visualize train, test, and validation samples.
1 parent 76cf315 commit f282b20

File tree

70 files changed

+458
-238
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+458
-238
lines changed

cfg/model/dae.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ defaults:
33

44
name: DAE
55
# Model hyperparameters
6-
obs_pred_w: 15.0 # Cost function weight for prediction in observation space Z
7-
orth_w: 1.0 # Weight of the orthonormal regularization term in the loss function
6+
obs_pred_w: 1.0 # Cost function weight for prediction in observation space Z
7+
orth_w: 0.0 # Weight of the orthonormal regularization term in the loss function
88
corr_w: 0.0
99

1010
# Optimization hyperparameters parameters
11-
lr: 1e-3
11+
lr: 5e-4
1212
equivariant: False
1313

1414
summary: ${model.name}-Obs_w:${model.obs_pred_w}-Orth_w:${model.orth_w}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}

cfg/model/dpnet.yaml

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ equivariant: False
99
# Model hyperparameters
1010
batch_norm: True
1111
bias: False
12+
1213
# Optimization hyperparameters parameters
1314
lr: 1e-3
1415

16+
# Model hyperparameters
1517
max_ck_window_length: ${system.pred_horizon} # Maximum length of the Chapman-Kolmogorov window
1618
ck_w: 0.0 # Weight of the Chapman-Kolmogorov regularization term in the loss function
17-
orth_w: 0.5 # Weight of the orthonormal regularization term in the loss function
18-
aux_obs_space: False # Whether to use an auxiliary observable space.
19+
orth_w: 1.0 # Weight of the orthonormal regularization term in the loss function
20+
aux_obs_space: True # Whether to use an auxiliary observable space.
1921
use_spectral_score: True # Whether to use the spectral or the correlation score
2022

2123
summary: ${model.name}-CK_w:${model.ck_w}-Orth_w:${model.orth_w}-Win:${model.max_ck_window_length}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}-AOS:${model.aux_obs_space}-SS:${model.use_spectral_score}

cfg/model/edae.yaml

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@ defaults:
33

44
name: E-DAE
55
# Model hyperparameters
6+
obs_pred_w: 1.0 # Cost function weight for prediction in observation space Z
7+
orth_w: 0.0 # Weight of the orthonormal regularization term in the loss function
8+
corr_w: 0.5
9+
610
state_dependent_obs_dyn: False # Whether to use state-dependent observation dynamics
711
group_avg_trick: True
812

9-
13+
# Optimization hyperparameters parameters
14+
lr: 5e-4
1015
equivariant: True
1116

cfg/model/edpnet.yaml

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@ defaults:
33

44
name: E-DPNet
55

6+
# Optimization hyperparameters parameters
7+
lr: 1e-3
8+
69
# Symmetry exploitation parameters
710
equivariant: True
11+
12+
# Model hyperparameters
813
group_avg_trick: True
14+
ck_w: 0.0
915

10-
activation: ELU

cfg/system/mini_cheetah.yaml

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
defaults:
2+
- base_system
3+
4+
name: 'mini_cheetah'
5+
6+
obs_state_dim: 64 # Dimension of the system's observable state
7+
8+
state_obs: ['q']
9+
action_obs: []
10+
11+
# dt = 0.0012 s this is the average delta time between observations.
12+
frames_per_state: 1 # Number of time-frames to use as a Markov Process state time step
13+
pred_horizon: 50 # Number (or percentage) of Markov Process state time steps to predict into the future
14+
eval_pred_horizon: 200 # Number (or percentage) of Markov Process state time steps to predict into the future
15+
16+
dynamic_mode: grass
17+
18+
data_path: '${system.name}/recordings/${system.dynamic_mode}'
19+
20+
summary: S:${system.dynamic_mode}-OS:${system.obs_state_dim}-H:${system.pred_horizon}-EH:${system.eval_pred_horizon}

data/DynamicsDataModule.py

+86-73
Large diffs are not rendered by default.

data/DynamicsRecording.py

+176-156
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)