1
1
import logging
2
2
import time
3
3
from pathlib import Path
4
- from typing import Optional , Union
4
+ from typing import Any , Optional , Union
5
5
6
6
import escnn .group
7
7
import numpy as np
@@ -26,6 +26,7 @@ def __init__(self,
26
26
data_path : Path ,
27
27
pred_horizon : Union [int , float ] = 0.25 ,
28
28
eval_pred_horizon : Union [int , float ] = 0.5 ,
29
+ test_pred_horizon : Union [int , float ] = 0.5 ,
29
30
system_cfg : Optional [dict ] = None ,
30
31
batch_size : int = 256 ,
31
32
frames_per_step : int = 1 ,
@@ -50,6 +51,7 @@ def __init__(self,
50
51
assert pred_horizon >= 1 , "At least we need to forecast a single dynamics step"
51
52
self .pred_horizon = pred_horizon
52
53
self .eval_pred_horizon = eval_pred_horizon
54
+ self .test_pred_horizon = test_pred_horizon
53
55
self .batch_size = batch_size
54
56
self .num_workers = num_workers
55
57
# Metadata and dynamics information
@@ -63,12 +65,14 @@ def __init__(self,
63
65
self .gspace = None
64
66
self .state_field_type , self .action_field_type = None , None
65
67
self .device = device
68
+ self .train_dataset , self .val_dataset , self .test_dataset = None , None , None
66
69
67
70
def prepare_data (self ):
68
71
69
72
if self .prepared :
70
- self ._train_dataset = self ._train_dataset .shuffle (buffer_size = self ._train_dataset .dataset_size / 2 )
71
- self ._train_dataloader = DataLoader (dataset = self ._train_dataset , batch_size = self .batch_size ,
73
+ self .train_dataset = self .train_dataset .shuffle (
74
+ buffer_size = min (self .train_dataset .dataset_size // 4 , 5000 ))
75
+ self ._train_dataloader = DataLoader (dataset = self .train_dataset , batch_size = self .batch_size ,
72
76
num_workers = self .num_workers ,
73
77
persistent_workers = True if self .num_workers > 0 else False ,
74
78
collate_fn = self .data_augmentation_collate_fn if self .augment else
@@ -79,7 +83,6 @@ def prepare_data(self):
79
83
start_time = time .time ()
80
84
log .info (f"Preparing datasets { self ._data_path } " )
81
85
82
- a = list (self ._data_path .rglob ('*train.pkl' ))
83
86
dyn_sys_data = set ([a .parent for a in list (self ._data_path .rglob ('*train.pkl' ))])
84
87
if self .noise_level is not None :
85
88
system_data_path = [path for path in dyn_sys_data if f"noise_level={ self .noise_level } " in str (path )]
@@ -99,6 +102,7 @@ def prepare_data(self):
99
102
val_shards = val_data ,
100
103
train_pred_horizon = self .pred_horizon ,
101
104
eval_pred_horizon = self .eval_pred_horizon ,
105
+ test_pred_horizon = self .test_pred_horizon ,
102
106
frames_per_step = self .frames_per_step ,
103
107
state_obs = self .state_obs ,
104
108
action_obs = self .action_obs )
@@ -118,20 +122,33 @@ def prepare_data(self):
118
122
test_dataset = test_dataset .shuffle (buffer_size = min (train_dataset .dataset_size // 4 , 1000 ), seed = 18 )
119
123
val_dataset = val_dataset .shuffle (buffer_size = min (train_dataset .dataset_size // 4 , 1000 ), seed = 18 )
120
124
# Convert to torch. Apply map to get samples containing state and next state
121
- obs_to_remove = train_dataset .features .keys ()
122
- train_dataset = train_dataset .with_format ("torch" ).map (
125
+ # After mapping to state next state, remove all other observations
126
+ obs_to_remove = set (train_dataset .features .keys ())
127
+ obs_to_remove .discard ('state' )
128
+ self .train_dataset = train_dataset .with_format ("torch" ).map (
123
129
DynamicsRecording .map_state_next_state , batched = True , fn_kwargs = {'state_observations' : self .state_obs },
124
130
remove_columns = tuple (obs_to_remove ))
125
- test_dataset = test_dataset .with_format ("torch" ).map (
131
+ self . test_dataset = test_dataset .with_format ("torch" ).map (
126
132
DynamicsRecording .map_state_next_state , batched = True , fn_kwargs = {'state_observations' : self .state_obs },
127
133
remove_columns = tuple (obs_to_remove ))
128
- val_dataset = val_dataset .with_format ("torch" ).map (
134
+ self . val_dataset = val_dataset .with_format ("torch" ).map (
129
135
DynamicsRecording .map_state_next_state , batched = True , fn_kwargs = {'state_observations' : self .state_obs },
130
136
remove_columns = tuple (obs_to_remove ))
131
137
132
- self ._train_dataset = train_dataset
133
- self ._test_dataset = test_dataset
134
- self ._val_dataset = val_dataset
138
+ # Configure the prediction dataloader for approximating and evaluating the transfer operator. This will
139
+ # be a dataloader passing state and next state single step measurements:
140
+ datasets , dynamics_recording = get_dynamics_dataset (train_shards = train_data ,
141
+ test_shards = None ,
142
+ val_shards = None ,
143
+ train_pred_horizon = 1 ,
144
+ eval_pred_horizon = 1 ,
145
+ frames_per_step = self .frames_per_step ,
146
+ state_obs = self .state_obs ,
147
+ action_obs = self .action_obs )
148
+ transfer_op_train_dataset , _ , _ = datasets
149
+ self ._transfer_op_train_dataset = transfer_op_train_dataset .with_format ("torch" ).map (
150
+ DynamicsRecording .map_state_next_state , batched = True , fn_kwargs = {'state_observations' : self .state_obs },
151
+ remove_columns = tuple (obs_to_remove ))
135
152
136
153
# Rebuild the ESCNN representations of measurements _________________________________________________________
137
154
# TODO: Handle dyn systems without symmetries
@@ -157,45 +174,11 @@ def prepare_data(self):
157
174
action_reps = [rep for frame_reps in action_reps for rep in frame_reps ] # flatten list of reps
158
175
self .action_field_type = FieldType (self .gspace , representations = action_reps )
159
176
160
- self ._train_dataloader = DataLoader (dataset = train_dataset , batch_size = self .batch_size ,
161
- num_workers = self .num_workers ,
162
- persistent_workers = True if self .num_workers > 0 else False ,
163
- collate_fn = self .data_augmentation_collate_fn if self .augment else
164
- self .collate_fn )
165
- batch_size = min (self .batch_size , test_dataset .dataset_size )
166
- self ._test_dataloader = DataLoader (dataset = test_dataset , batch_size = batch_size , shuffle = False ,
167
- persistent_workers = True if self .num_workers > 0 else False ,
168
- num_workers = self .num_workers ,
169
- # pin_memory=True,
170
- collate_fn = self .collate_fn )
171
- batch_size = min (self .batch_size , test_dataset .dataset_size )
172
- self ._val_dataloader = DataLoader (dataset = val_dataset , batch_size = batch_size , shuffle = False ,
173
- persistent_workers = True if self .num_workers > 0 else False ,
174
- num_workers = self .num_workers ,
175
- # pin_memory=True,
176
- collate_fn = self .collate_fn )
177
-
178
- # Configure the prediction dataloader for the approximating and evaluating the transfer operator. This will
179
- # be a dataloader passing state and next state single step measurements:
180
- datasets , dynamics_recording = get_dynamics_dataset (train_shards = train_data ,
181
- test_shards = test_data ,
182
- val_shards = val_data ,
183
- train_pred_horizon = 1 ,
184
- eval_pred_horizon = 1 ,
185
- frames_per_step = self .frames_per_step ,
186
- state_obs = self .state_obs ,
187
- action_obs = self .action_obs )
188
- transfer_op_train_dataset , _ , _ = datasets
189
- transfer_op_train_dataset = transfer_op_train_dataset .with_format ("torch" ).map (
190
- DynamicsRecording .map_state_next_state , batched = True , fn_kwargs = {'state_observations' : self .state_obs })
191
-
192
- self ._trans_op_dataloader = DataLoader (dataset = transfer_op_train_dataset ,
193
- batch_size = transfer_op_train_dataset .dataset_size , # Single batch
194
- pin_memory = False , num_workers = self .num_workers , shuffle = False ,
195
- collate_fn = self .collate_fn )
196
-
197
177
log .info (f"Data preparation done in { time .time () - start_time :.2f} [s]" )
198
178
179
def setup(self, stage: str) -> None:
    """Log which stage is being set up.

    No dataset work happens in this hook; it only emits a log record.

    Args:
        stage: Name of the stage being set up; it is interpolated into the log message.
    """
    message = f"Setting up {stage} dataset"
    log.info(message)
181
+
199
182
def compute_loss_metrics (self , predictions : dict , inputs : dict ) -> (torch .Tensor , dict ):
200
183
"""
201
184
Compute the loss and metrics from the predictions and inputs
@@ -206,20 +189,35 @@ def compute_loss_metrics(self, predictions: dict, inputs: dict) -> (torch.Tensor
206
189
raise NotImplementedError ("Implement this function in the derived class" )
207
190
208
191
def train_dataloader(self):
    """Build a fresh DataLoader over the training dataset.

    The augmentation collate function is used when ``self.augment`` is set,
    the same choice ``prepare_data`` makes when it rebuilds the training
    loader; otherwise batches are assembled with ``self.collate_fn``.

    Returns:
        DataLoader: loader over ``self.train_dataset`` yielding
        ``self.batch_size`` samples per batch; the last partial batch is kept.
    """
    # Honour the augmentation flag here too, so enabling it is not silently ignored.
    collate = self.data_augmentation_collate_fn if self.augment else self.collate_fn
    return DataLoader(dataset=self.train_dataset,
                      batch_size=self.batch_size,
                      num_workers=self.num_workers,
                      collate_fn=collate,
                      # Persistent workers are only requested when worker processes exist.
                      persistent_workers=self.num_workers > 0,
                      drop_last=False)
210
196
211
197
def val_dataloader(self):
    """Create the validation DataLoader.

    Returns:
        DataLoader: unshuffled loader over ``self.val_dataset`` that batches
        with ``self.collate_fn`` and keeps the final partial batch.
    """
    use_persistent_workers = bool(self.num_workers > 0)
    loader_options = dict(
        batch_size=self.batch_size,
        shuffle=False,
        persistent_workers=use_persistent_workers,
        num_workers=self.num_workers,
        collate_fn=self.collate_fn,
        drop_last=False,
    )
    return DataLoader(dataset=self.val_dataset, **loader_options)
213
203
214
204
def test_dataloader (self ):
215
- return self ._test_dataloader
205
+ return DataLoader (dataset = self .test_dataset , batch_size = self .batch_size , shuffle = False ,
206
+ persistent_workers = True if self .num_workers > 0 else False ,
207
+ collate_fn = self .collate_fn ,
208
+ num_workers = self .num_workers , drop_last = False )
216
209
217
210
def predict_dataloader(self):
    """Create the DataLoader over the transfer-operator training dataset.

    Returns:
        DataLoader: unshuffled loader over ``self._transfer_op_train_dataset``
        with pinned memory disabled; the final partial batch is kept.
    """
    transfer_op_data = self._transfer_op_train_dataset
    loader = DataLoader(
        dataset=transfer_op_data,
        batch_size=self.batch_size,
        pin_memory=False,
        collate_fn=self.collate_fn,
        num_workers=self.num_workers,
        shuffle=False,
        drop_last=False,
    )
    return loader
214
+
215
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
    """Move a batch to ``device`` by deferring to the parent class implementation.

    Args:
        batch: The batch to transfer.
        device: Target device.
        dataloader_idx: Index of the dataloader the batch came from.

    Returns:
        Whatever the parent implementation returns for these arguments.
    """
    moved_batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return moved_batch
219
217
220
218
@property
def prepared(self):
    """Whether the training dataset has been built.

    Returns:
        bool: ``True`` once ``self.train_dataset`` holds something other than ``None``.
    """
    train_dataset_missing = self.train_dataset is None
    return not train_dataset_missing
223
221
224
222
def collate_fn (self , batch_list : list ) -> dict :
225
223
batch = torch .utils .data .default_collate (batch_list )
@@ -260,8 +258,8 @@ def plot_sample_trajs(self):
260
258
num_trajs = 5
261
259
fig = None
262
260
styles = {'Train' : dict (width = 3 , dash = 'solid' ),
263
- 'Test' : dict (width = 2 , dash = '2px' ),
264
- 'Val' : dict (width = 1 , dash = '5px' )}
261
+ 'Test' : dict (width = 2 , dash = '2px' ),
262
+ 'Val' : dict (width = 1 , dash = '5px' )}
265
263
for partition , dataloader in zip (['Train' , 'Test' , 'Val' ],
266
264
[self .train_dataloader (), self .test_dataloader (), self .val_dataloader ()]):
267
265
batch = next (iter (dataloader ))
@@ -280,25 +278,28 @@ def plot_sample_trajs(self):
280
278
281
279
# Find all dynamic systems recordings
282
280
path_to_data /= Path ('mini_cheetah' ) / 'recordings' / 'grass'
281
+ # path_to_data = Path('/home/danfoa/Projects/koopman_robotics/data/linear_system/group=C10-dim=10/n_constraints=0
282
+ # /f_time_constant=1.5[s]-frames=200-horizon=8.7[s]/noise_level=0')
283
283
path_to_dyn_sys_data = set ([a .parent for a in list (path_to_data .rglob ('*train.pkl' ))])
284
284
# Select a dynamical system
285
285
mock_path = path_to_dyn_sys_data .pop ()
286
286
287
287
data_module = DynamicsDataModule (data_path = mock_path ,
288
288
pred_horizon = 10 ,
289
- eval_pred_horizon = 100 ,
289
+ eval_pred_horizon = 200 ,
290
290
frames_per_step = 1 ,
291
291
num_workers = 1 ,
292
292
batch_size = 1000 ,
293
293
augment = False ,
294
- state_obs = ('q' , ),
295
- action_obs = tuple (),
294
+ state_obs = ('v' ),
295
+ # action_obs=tuple(),
296
296
)
297
297
298
298
# Test loading of the DynamicsRecording
299
299
data_module .prepare_data ()
300
+ s = next (iter (data_module .train_dataset ))
301
+
300
302
data_module .plot_sample_trajs ()
301
- s = next (iter (data_module ._train_dataset ))
302
303
states , state_trajs = None , None
303
304
fig = None
304
305
0 commit comments