11
11
from datasets import Features , IterableDataset
12
12
from escnn .group import Representation , groups_dict
13
13
14
- from utils .mysc import compare_dictionaries
14
+ from utils .mysc import TemporaryNumpySeed , compare_dictionaries
15
15
16
16
log = logging .getLogger (__name__ )
17
17
@@ -217,12 +217,62 @@ def estimate_dataset_size(recordings: list[DynamicsRecording], prediction_horizo
217
217
log .debug (f"Steps in prediction horizon { int (steps_pred_horizon )} " )
218
218
return num_trajs , num_samples
219
219
220
def reduce_dataset_size(recordings: Iterable[DynamicsRecording], train_ratio: float = 1.0):
    """Reduce each recording to (approximately) a `train_ratio` fraction of its samples, in place.

    Two pruning strategies are used per recording:
      - Fewer than 10 trajectories: each trajectory's time axis is split into
        equal-duration "virtual" sub-trajectories and a random subset of the
        partitions is kept (so no whole trajectory is discarded).
      - Otherwise: a random subset of whole trajectories is kept.

    Args:
        recordings: Recordings to prune. Each element is modified in place.
        train_ratio: Fraction of the samples to keep, in (0, 1].

    Returns:
        The (pruned) `recordings`.
    """
    assert 0.0 < train_ratio <= 1.0, f"Invalid train ratio {train_ratio}"
    if train_ratio == 1.0:
        return recordings
    log.info(f"Reducing dataset size to {train_ratio * 100}%")
    # Fixed seed: all training seeds must select identical data partitions.
    with TemporaryNumpySeed(10):
        for r in recordings:
            num_trajs = r.info['num_traj']
            if num_trajs < 10:
                # Too few trajectories to drop whole ones; drop parts of each instead.
                # Time horizon from the first state observation — assumes obs arrays
                # are (traj, time, ...); consistent with the axis-0 indexing below.
                time_horizon = r.recordings[r.state_obs[0]].shape[1]

                # Split the time axis into "virtual" sub-trajectories.
                n_partitions = math.ceil(1 / (1 - train_ratio))
                n_partitions = n_partitions * 10 if n_partitions < 10 else n_partitions
                # Truncate the tail so all partitions share the same duration.
                # NOTE: the original `idx[:-(time_horizon % n_partitions)]` produced an
                # EMPTY index array whenever the remainder was 0 (`idx[:-0] == idx[:0]`).
                usable_steps = time_horizon - (time_horizon % n_partitions)
                partitions_idx = np.split(np.arange(usable_steps), indices_or_sections=n_partitions)
                partition_length = partitions_idx[0].shape[0]
                n_partitions_to_keep = math.ceil(n_partitions * train_ratio)
                partitions_to_keep = np.random.choice(range(n_partitions),
                                                      size=n_partitions_to_keep,
                                                      replace=False)
                partitions_to_keep = [partitions_idx[i] for i in partitions_to_keep]

                # Sanity check: the achieved reduction must be close to the requested one.
                ratio_of_samples_removed = (time_horizon - (len(partitions_to_keep) * partition_length)) / time_horizon
                assert ratio_of_samples_removed - (1 - train_ratio) < 0.05, \
                    (f"Requested to remove {(1 - train_ratio) * 100}% of the samples, "
                     f"but removed {ratio_of_samples_removed * 100}%")

                new_recordings = {}
                for obs_name, obs in r.recordings.items():
                    # Each kept time-partition of every trajectory becomes its own
                    # (shorter) trajectory along axis 0.
                    new_obs_trajs = [obs[:, part_time_idx] for part_time_idx in partitions_to_keep]
                    new_recordings[obs_name] = np.concatenate(new_obs_trajs, axis=0)
                r.recordings = new_recordings
                # Concatenation along axis 0 yields one sub-trajectory per
                # (trajectory, kept-partition) pair; the original stored only
                # `len(partitions_to_keep)`, under-counting when num_trajs > 1.
                r.info['num_traj'] = num_trajs * len(partitions_to_keep)
                r.info['trajectory_length'] = partition_length
            else:  # Enough trajectories: discard whole ones.
                # Sample ceil(num_trajs * train_ratio) trajectories to keep.
                num_trajs_to_keep = math.ceil(num_trajs * train_ratio)
                idx_to_keep = np.random.choice(range(num_trajs), size=num_trajs_to_keep, replace=False)
                r.recordings = {k: v[idx_to_keep] for k, v in r.recordings.items()}
                # Keep metadata consistent with the pruned data (left stale in the
                # original, unlike the partition branch above).
                r.info['num_traj'] = num_trajs_to_keep
    # Return on every path, so callers can rely on a value regardless of whether
    # the `train_ratio == 1.0` shortcut was taken (the original returned None here).
    return recordings
220
269
221
270
def get_dynamics_dataset (train_shards : list [Path ],
222
271
test_shards : Optional [list [Path ]] = None ,
223
272
val_shards : Optional [list [Path ]] = None ,
224
273
num_proc : int = 1 ,
225
274
frames_per_step : int = 1 ,
275
+ train_ratio : float = 1.0 ,
226
276
train_pred_horizon : Union [int , float ] = 1 ,
227
277
eval_pred_horizon : Union [int , float ] = 10 ,
228
278
test_pred_horizon : Union [int , float ] = 10 ,
@@ -260,6 +310,8 @@ def get_dynamics_dataset(train_shards: list[Path],
260
310
recordings = [DynamicsRecording .load_from_file (f , obs_names = relevant_obs ) for f in partition_shards ]
261
311
if partition == "train" :
262
312
pred_horizon = train_pred_horizon
313
+ if train_ratio < 1.0 :
314
+ reduce_dataset_size (recordings , train_ratio )
263
315
elif partition == "val" :
264
316
pred_horizon = eval_pred_horizon
265
317
else :
@@ -275,12 +327,8 @@ def get_dynamics_dataset(train_shards: list[Path],
275
327
action_obs = tuple (action_obs ))
276
328
)
277
329
278
- # for sample in dataset:
279
330
log .debug (f"[Dataset { partition } - Trajs:{ num_trajs } - Samples: { num_samples } - "
280
331
f"Frames per sample : { frames_per_step } ]-----------------------------" )
281
- # log.debug(f"\tstate: {state.shape} = (frames_per_step, state_dim)")
282
- # log.debug(f"\tnext_state: {next_state.shape} = (pred_horizon, frames_per_step, state_dim)")
283
- # break
284
332
285
333
dataset .info .dataset_size = num_samples
286
334
dataset .info .dataset_name = f"[{ partition } ] Linear dynamics"
0 commit comments