Skip to content

Commit 20be9c1

Browse files
committed
ObsNetwork shared encoder and general small fixes
- Switch to MorphoSymm EMLP / MLP networks. - Reduce the number of plots during training, for faster training.
1 parent b8cd185 commit 20be9c1

File tree

67 files changed

+301
-516
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+301
-516
lines changed

cfg/config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ device: 0
1313
num_workers: 0 # Dataloader workers
1414
debug: False
1515
debug_loops: False
16-
max_epochs: 200
16+
max_epochs: 150
1717

1818
# Markov Dynamics Model
1919
model: dpnet
@@ -37,7 +37,7 @@ hydra:
3737
chdir: True
3838
env_set:
3939
XLA_PYTHON_CLIENT_PREALLOCATE: 'false'
40-
HYDRA_FULL_ERROR: 1
40+
HYDRA_FULL_ERROR: '1'
4141
# CUDA_VISIBLE_DEVICES: ${.device}
4242
config:
4343
override_dirname:
@@ -59,7 +59,7 @@ hydra:
5959
- model.lr
6060
- model.batch_norm
6161
- model.bias
62-
- model.aux_obs_space
62+
- model.explicit_transfer_op
6363
- model.use_spectral_score
6464
- system.state_dim # Metrics included in system.summary
6565
- system.obs_state_dim

cfg/model/dae.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ defaults:
44
name: DAE
55
# Model hyperparameters
66
obs_pred_w: 1.0 # Cost function weight for prediction in observation space Z
7-
orth_w: 0.0 # Weight of the orthonormal regularization term in the loss function
7+
orth_w: 0.1 # Weight of the orthonormal regularization term in the loss function
88
corr_w: 0.0
99

1010
# Optimization hyperparameters parameters
11-
lr: 5e-4
11+
lr: 1e-3
1212
equivariant: False
1313

1414
summary: ${model.name}-Obs_w:${model.obs_pred_w}-Orth_w:${model.orth_w}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}

cfg/model/dpnet.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ lr: 1e-3
1717
max_ck_window_length: ${system.pred_horizon} # Maximum length of the Chapman-Kolmogorov window
1818
ck_w: 0.0 # Weight of the Chapman-Kolmogorov regularization term in the loss function
1919
orth_w: 1.0 # Weight of the orthonormal regularization term in the loss function
20-
aux_obs_space: True # Whether to use an auxiliary observable space.
20+
explicit_transfer_op: True # Whether to use a shared encoder network for the computation of the observations in H and H' (True) or to use two separate encoder networks (False)
2121
use_spectral_score: True # Whether to use the spectral or the correlation score
2222

23-
summary: ${model.name}-CK_w:${model.ck_w}-Orth_w:${model.orth_w}-Win:${model.max_ck_window_length}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}-AOS:${model.aux_obs_space}-SS:${model.use_spectral_score}
23+
summary: ${model.name}-CK_w:${model.ck_w}-Orth_w:${model.orth_w}-Win:${model.max_ck_window_length}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}-ETO:${model.explicit_transfer_op}-SS:${model.use_spectral_score}
2424

cfg/system/linear_system.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ name: 'linear_system'
66
state_dim: 10 # Dimension of the system's state
77
obs_state_dim: 30 # Dimension of the system's observable state
88

9-
frames_per_state: 1 # Number of time-frames to use as a Markov Process state time step
10-
pred_horizon: 10 # Number (or percentage) of Markov Process state time steps to predict into the future
11-
eval_pred_horizon: .5 # Number (or percentage) of Markov Process state time steps to predict into the future
12-
9+
frames_per_state: 1 # Number of time-frames to use as a Markov Process state time step
10+
pred_horizon: 10 # Number (or percentage) of Markov Process state time steps to predict into the future
11+
eval_pred_horizon: ${system.pred_horizon} # Number (or percentage) of Markov Process state time steps to predict into the future
12+
test_pred_horizon: .5
1313
group: C10
1414
noise_level: 2
1515
n_constraints: 0

cfg/system/mini_cheetah.yaml

+5-4
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@ defaults:
33

44
name: 'mini_cheetah'
55

6-
obs_state_dim: 64 # Dimension of the system's observable state
6+
obs_state_dim: 32 # Dimension of the system's observable state
77

88
state_obs: ['q']
99
action_obs: []
1010

1111
# dt = 0.0012 s this is the average delta time between observations.
12-
frames_per_state: 1 # Number of time-frames to use as a Markov Process state time step
13-
pred_horizon: 50 # Number (or percentage) of Markov Process state time steps to predict into the future
14-
eval_pred_horizon: 200 # Number (or percentage) of Markov Process state time steps to predict into the future
12+
frames_per_state: 1 # Number of time-frames to use as a Markov Process state time step
13+
pred_horizon: 10 # Number (or percentage) of Markov Process state time steps to predict into the future
14+
eval_pred_horizon: ${system.pred_horizon} # Number (or percentage) of Markov Process state time steps to predict into the future
15+
test_pred_horizon: 200
1516

1617
dynamic_mode: grass
1718

data/DynamicsDataModule.py

+60-59
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import time
33
from pathlib import Path
4-
from typing import Optional, Union
4+
from typing import Any, Optional, Union
55

66
import escnn.group
77
import numpy as np
@@ -26,6 +26,7 @@ def __init__(self,
2626
data_path: Path,
2727
pred_horizon: Union[int, float] = 0.25,
2828
eval_pred_horizon: Union[int, float] = 0.5,
29+
test_pred_horizon: Union[int, float] = 0.5,
2930
system_cfg: Optional[dict] = None,
3031
batch_size: int = 256,
3132
frames_per_step: int = 1,
@@ -50,6 +51,7 @@ def __init__(self,
5051
assert pred_horizon >= 1, "At least we need to forecast a single dynamics step"
5152
self.pred_horizon = pred_horizon
5253
self.eval_pred_horizon = eval_pred_horizon
54+
self.test_pred_horizon = test_pred_horizon
5355
self.batch_size = batch_size
5456
self.num_workers = num_workers
5557
# Metadata and dynamics information
@@ -63,12 +65,14 @@ def __init__(self,
6365
self.gspace = None
6466
self.state_field_type, self.action_field_type = None, None
6567
self.device = device
68+
self.train_dataset, self.val_dataset, self.test_dataset = None, None, None
6669

6770
def prepare_data(self):
6871

6972
if self.prepared:
70-
self._train_dataset = self._train_dataset.shuffle(buffer_size=self._train_dataset.dataset_size / 2)
71-
self._train_dataloader = DataLoader(dataset=self._train_dataset, batch_size=self.batch_size,
73+
self.train_dataset = self.train_dataset.shuffle(
74+
buffer_size=min(self.train_dataset.dataset_size // 4, 5000))
75+
self._train_dataloader = DataLoader(dataset=self.train_dataset, batch_size=self.batch_size,
7276
num_workers=self.num_workers,
7377
persistent_workers=True if self.num_workers > 0 else False,
7478
collate_fn=self.data_augmentation_collate_fn if self.augment else
@@ -79,7 +83,6 @@ def prepare_data(self):
7983
start_time = time.time()
8084
log.info(f"Preparing datasets {self._data_path}")
8185

82-
a = list(self._data_path.rglob('*train.pkl'))
8386
dyn_sys_data = set([a.parent for a in list(self._data_path.rglob('*train.pkl'))])
8487
if self.noise_level is not None:
8588
system_data_path = [path for path in dyn_sys_data if f"noise_level={self.noise_level}" in str(path)]
@@ -99,6 +102,7 @@ def prepare_data(self):
99102
val_shards=val_data,
100103
train_pred_horizon=self.pred_horizon,
101104
eval_pred_horizon=self.eval_pred_horizon,
105+
test_pred_horizon=self.test_pred_horizon,
102106
frames_per_step=self.frames_per_step,
103107
state_obs=self.state_obs,
104108
action_obs=self.action_obs)
@@ -118,20 +122,33 @@ def prepare_data(self):
118122
test_dataset = test_dataset.shuffle(buffer_size=min(train_dataset.dataset_size // 4, 1000), seed=18)
119123
val_dataset = val_dataset.shuffle(buffer_size=min(train_dataset.dataset_size // 4, 1000), seed=18)
120124
# Convert to torch. Apply map to get samples containing state and next state
121-
obs_to_remove = train_dataset.features.keys()
122-
train_dataset = train_dataset.with_format("torch").map(
125+
# After mapping to state next state, remove all other observations
126+
obs_to_remove = set(train_dataset.features.keys())
127+
obs_to_remove.discard('state')
128+
self.train_dataset = train_dataset.with_format("torch").map(
123129
DynamicsRecording.map_state_next_state, batched=True, fn_kwargs={'state_observations': self.state_obs},
124130
remove_columns=tuple(obs_to_remove))
125-
test_dataset = test_dataset.with_format("torch").map(
131+
self.test_dataset = test_dataset.with_format("torch").map(
126132
DynamicsRecording.map_state_next_state, batched=True, fn_kwargs={'state_observations': self.state_obs},
127133
remove_columns=tuple(obs_to_remove))
128-
val_dataset = val_dataset.with_format("torch").map(
134+
self.val_dataset = val_dataset.with_format("torch").map(
129135
DynamicsRecording.map_state_next_state, batched=True, fn_kwargs={'state_observations': self.state_obs},
130136
remove_columns=tuple(obs_to_remove))
131137

132-
self._train_dataset = train_dataset
133-
self._test_dataset = test_dataset
134-
self._val_dataset = val_dataset
138+
# Configure the prediction dataloader for the approximating and evaluating the transfer operator. This will
139+
# be a dataloader passing state and next state single step measurements:
140+
datasets, dynamics_recording = get_dynamics_dataset(train_shards=train_data,
141+
test_shards=None,
142+
val_shards=None,
143+
train_pred_horizon=1,
144+
eval_pred_horizon=1,
145+
frames_per_step=self.frames_per_step,
146+
state_obs=self.state_obs,
147+
action_obs=self.action_obs)
148+
transfer_op_train_dataset, _, _ = datasets
149+
self._transfer_op_train_dataset = transfer_op_train_dataset.with_format("torch").map(
150+
DynamicsRecording.map_state_next_state, batched=True, fn_kwargs={'state_observations': self.state_obs},
151+
remove_columns=tuple(obs_to_remove))
135152

136153
# Rebuilt the ESCNN representations of measurements _________________________________________________________
137154
# TODO: Handle dyn systems without symmetries
@@ -157,45 +174,11 @@ def prepare_data(self):
157174
action_reps = [rep for frame_reps in action_reps for rep in frame_reps] # flatten list of reps
158175
self.action_field_type = FieldType(self.gspace, representations=action_reps)
159176

160-
self._train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.batch_size,
161-
num_workers=self.num_workers,
162-
persistent_workers=True if self.num_workers > 0 else False,
163-
collate_fn=self.data_augmentation_collate_fn if self.augment else
164-
self.collate_fn)
165-
batch_size = min(self.batch_size, test_dataset.dataset_size)
166-
self._test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
167-
persistent_workers=True if self.num_workers > 0 else False,
168-
num_workers=self.num_workers,
169-
# pin_memory=True,
170-
collate_fn=self.collate_fn)
171-
batch_size = min(self.batch_size, test_dataset.dataset_size)
172-
self._val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False,
173-
persistent_workers=True if self.num_workers > 0 else False,
174-
num_workers=self.num_workers,
175-
# pin_memory=True,
176-
collate_fn=self.collate_fn)
177-
178-
# Configure the prediction dataloader for the approximating and evaluating the transfer operator. This will
179-
# be a dataloader passing state and next state single step measurements:
180-
datasets, dynamics_recording = get_dynamics_dataset(train_shards=train_data,
181-
test_shards=test_data,
182-
val_shards=val_data,
183-
train_pred_horizon=1,
184-
eval_pred_horizon=1,
185-
frames_per_step=self.frames_per_step,
186-
state_obs=self.state_obs,
187-
action_obs=self.action_obs)
188-
transfer_op_train_dataset, _, _ = datasets
189-
transfer_op_train_dataset = transfer_op_train_dataset.with_format("torch").map(
190-
DynamicsRecording.map_state_next_state, batched=True, fn_kwargs={'state_observations': self.state_obs})
191-
192-
self._trans_op_dataloader = DataLoader(dataset=transfer_op_train_dataset,
193-
batch_size=transfer_op_train_dataset.dataset_size, # Single batch
194-
pin_memory=False, num_workers=self.num_workers, shuffle=False,
195-
collate_fn=self.collate_fn)
196-
197177
log.info(f"Data preparation done in {time.time() - start_time:.2f} [s]")
198178

179+
def setup(self, stage: str) -> None:
180+
log.info(f"Setting up {stage} dataset")
181+
199182
def compute_loss_metrics(self, predictions: dict, inputs: dict) -> (torch.Tensor, dict):
200183
"""
201184
Compute the loss and metrics from the predictions and inputs
@@ -206,20 +189,35 @@ def compute_loss_metrics(self, predictions: dict, inputs: dict) -> (torch.Tensor
206189
raise NotImplementedError("Implement this function in the derived class")
207190

208191
def train_dataloader(self):
209-
return self._train_dataloader
192+
return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size,
193+
num_workers=self.num_workers,
194+
collate_fn=self.collate_fn,
195+
persistent_workers=True if self.num_workers > 0 else False, drop_last=False)
210196

211197
def val_dataloader(self):
212-
return self._val_dataloader
198+
return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size, shuffle=False,
199+
persistent_workers=True if self.num_workers > 0 else False,
200+
num_workers=self.num_workers,
201+
collate_fn=self.collate_fn,
202+
drop_last=False)
213203

214204
def test_dataloader(self):
215-
return self._test_dataloader
205+
return DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False,
206+
persistent_workers=True if self.num_workers > 0 else False,
207+
collate_fn=self.collate_fn,
208+
num_workers=self.num_workers, drop_last=False)
216209

217210
def predict_dataloader(self):
218-
return self._trans_op_dataloader
211+
return DataLoader(dataset=self._transfer_op_train_dataset, batch_size=self.batch_size, pin_memory=False,
212+
collate_fn=self.collate_fn,
213+
num_workers=self.num_workers, shuffle=False, drop_last=False)
214+
215+
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
216+
return super().transfer_batch_to_device(batch, device, dataloader_idx)
219217

220218
@property
221219
def prepared(self):
222-
return self._train_dataloader is not None
220+
return self.train_dataset is not None
223221

224222
def collate_fn(self, batch_list: list) -> dict:
225223
batch = torch.utils.data.default_collate(batch_list)
@@ -260,8 +258,8 @@ def plot_sample_trajs(self):
260258
num_trajs = 5
261259
fig = None
262260
styles = {'Train': dict(width=3, dash='solid'),
263-
'Test': dict(width=2, dash='2px'),
264-
'Val': dict(width=1, dash='5px')}
261+
'Test': dict(width=2, dash='2px'),
262+
'Val': dict(width=1, dash='5px')}
265263
for partition, dataloader in zip(['Train', 'Test', 'Val'],
266264
[self.train_dataloader(), self.test_dataloader(), self.val_dataloader()]):
267265
batch = next(iter(dataloader))
@@ -280,25 +278,28 @@ def plot_sample_trajs(self):
280278

281279
# Find all dynamic systems recordings
282280
path_to_data /= Path('mini_cheetah') / 'recordings' / 'grass'
281+
# path_to_data = Path('/home/danfoa/Projects/koopman_robotics/data/linear_system/group=C10-dim=10/n_constraints=0
282+
# /f_time_constant=1.5[s]-frames=200-horizon=8.7[s]/noise_level=0')
283283
path_to_dyn_sys_data = set([a.parent for a in list(path_to_data.rglob('*train.pkl'))])
284284
# Select a dynamical system
285285
mock_path = path_to_dyn_sys_data.pop()
286286

287287
data_module = DynamicsDataModule(data_path=mock_path,
288288
pred_horizon=10,
289-
eval_pred_horizon=100,
289+
eval_pred_horizon=200,
290290
frames_per_step=1,
291291
num_workers=1,
292292
batch_size=1000,
293293
augment=False,
294-
state_obs=('q',),
295-
action_obs=tuple(),
294+
state_obs=('v'),
295+
# action_obs=tuple(),
296296
)
297297

298298
# Test loading of the DynamicsRecording
299299
data_module.prepare_data()
300+
s = next(iter(data_module.train_dataset))
301+
300302
data_module.plot_sample_trajs()
301-
s = next(iter(data_module._train_dataset))
302303
states, state_trajs = None, None
303304
fig = None
304305

0 commit comments

Comments (0)