
Commit 4e31dc9

Add dataset reduction capability
This commit adds a deterministic partitioning of a DynamicsRecording so that only a train_ratio fraction of the training data is used. Recordings with fewer than 10 trajectories keep a randomly selected subset of equal-length "virtual" partitions of each trajectory; recordings with 10 or more discard whole trajectories instead. A fixed NumPy seed makes the selection identical across training runs.
1 parent 1d4a770 commit 4e31dc9

File tree

  data/DynamicsDataModule.py
  data/DynamicsRecording.py
  utils/mysc.py

3 files changed: +65 −5 lines changed


data/DynamicsDataModule.py

+1

@@ -99,6 +99,7 @@ def prepare_data(self):
         datasets, dynamics_recording = get_dynamics_dataset(train_shards=train_data,
                                                             test_shards=test_data,
                                                             val_shards=val_data,
+                                                            train_ratio=self.train_ratio,
                                                             train_pred_horizon=self.pred_horizon,
                                                             eval_pred_horizon=self.eval_pred_horizon,
                                                             test_pred_horizon=self.test_pred_horizon,
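The new keyword argument assumes DynamicsDataModule already stores a train_ratio attribute, whose definition is not part of this diff. A minimal sketch of the presumed constructor, assuming the class is a pytorch_lightning LightningDataModule (as the prepare_data hook suggests); everything here beyond self.train_ratio is hypothetical:

import pytorch_lightning as pl


class DynamicsDataModule(pl.LightningDataModule):
    # Hypothetical sketch: only the existence of self.train_ratio is implied by the diff above.
    def __init__(self, train_ratio: float = 1.0):
        super().__init__()
        assert 0.0 < train_ratio <= 1.0, f"Invalid train ratio {train_ratio}"
        self.train_ratio = train_ratio  # Fraction of the training data to keep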

data/DynamicsRecording.py

+53 −5

@@ -11,7 +11,7 @@
 from datasets import Features, IterableDataset
 from escnn.group import Representation, groups_dict
 
-from utils.mysc import compare_dictionaries
+from utils.mysc import TemporaryNumpySeed, compare_dictionaries
 
 log = logging.getLogger(__name__)
 
@@ -217,12 +217,62 @@ def estimate_dataset_size(recordings: list[DynamicsRecording], prediction_horizon
     log.debug(f"Steps in prediction horizon {int(steps_pred_horizon)}")
     return num_trajs, num_samples
 
+def reduce_dataset_size(recordings: Iterable[DynamicsRecording], train_ratio: float = 1.0):
+    assert 0.0 < train_ratio <= 1.0, f"Invalid train ratio {train_ratio}"
+    if train_ratio == 1.0:
+        return recordings
+    log.info(f"Reducing dataset size to {train_ratio * 100}%")
+    # Ensure all training seeds use the same training data partitions
+    with TemporaryNumpySeed(10):
+        for r in recordings:
+            # Decide how to keep a ratio of the original trajectories
+            num_trajs = r.info['num_traj']
+            if num_trajs < 10:  # Do not discard entire trajectories, but rather parts of each trajectory
+                time_horizon = r.recordings[r.state_obs[0]].shape[1]  # Take the time horizon from the first observation
+
+                # Split each trajectory into "virtual" sub-trajectories, and discard some of them
+                idx = np.arange(time_horizon)
+                n_partitions = math.ceil(1 / (1 - train_ratio))
+                n_partitions = n_partitions * 10 if n_partitions < 10 else n_partitions
+                # Ensure partitions of the same duration; trim the remainder only when it is non-zero
+                partitions_idx = np.split(idx[:-(time_horizon % n_partitions)] if time_horizon % n_partitions else idx,
+                                          indices_or_sections=n_partitions)
+                partition_length = partitions_idx[0].shape[0]
+                n_partitions_to_keep = math.ceil(n_partitions * train_ratio)
+                partitions_to_keep = np.random.choice(range(n_partitions),
+                                                      size=n_partitions_to_keep,
+                                                      replace=False)
+                log.debug(f"Keeping partitions {partitions_to_keep}")
+                partitions_to_keep = [partitions_idx[i] for i in partitions_to_keep]
+
+                ratio_of_samples_removed = (time_horizon - (len(partitions_to_keep) * partition_length)) / time_horizon
+                assert ratio_of_samples_removed - (1 - train_ratio) < 0.05, \
+                    (f"Requested to remove {(1 - train_ratio) * 100}% of the samples, "
+                     f"but removed {ratio_of_samples_removed * 100}%")
+
+                new_recordings = {}
+                for obs_name, obs in r.recordings.items():
+                    new_obs_trajs = []
+                    for part_time_idx in partitions_to_keep:
+                        new_obs_trajs.append(obs[:, part_time_idx])
+                    new_recordings[obs_name] = np.concatenate(new_obs_trajs, axis=0)
+                r.recordings = new_recordings
+                # Each kept partition of every trajectory becomes a new, shorter trajectory
+                r.info['num_traj'] = num_trajs * len(partitions_to_keep)
+                r.info['trajectory_length'] = partition_length
+            else:  # Discard entire trajectories
+                # Sample ceil(num_trajs * train_ratio) trajectories from the original recordings
+                num_trajs_to_keep = math.ceil(num_trajs * train_ratio)
+                idx = range(num_trajs)
+                idx_to_keep = np.random.choice(idx, size=num_trajs_to_keep, replace=False)
+                # Keep only the selected trajectories
+                new_recordings = {k: v[idx_to_keep] for k, v in r.recordings.items()}
+                r.recordings = new_recordings
 
 def get_dynamics_dataset(train_shards: list[Path],
                          test_shards: Optional[list[Path]] = None,
                          val_shards: Optional[list[Path]] = None,
                          num_proc: int = 1,
                          frames_per_step: int = 1,
+                         train_ratio: float = 1.0,
                          train_pred_horizon: Union[int, float] = 1,
                          eval_pred_horizon: Union[int, float] = 10,
                          test_pred_horizon: Union[int, float] = 10,
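A quick sanity check of the partition arithmetic added above (a standalone sketch, not part of the commit; the 1000-step time horizon is a hypothetical value chosen for round numbers). The tenfold inflation of n_partitions presumably exists so that the ceil rounding stays close to the requested ratio:

import math
import numpy as np

train_ratio = 0.75
time_horizon = 1000  # Hypothetical trajectory length

n_partitions = math.ceil(1 / (1 - train_ratio))                          # ceil(1 / 0.25) = 4
n_partitions = n_partitions * 10 if n_partitions < 10 else n_partitions  # 4 < 10, so 40 partitions
remainder = time_horizon % n_partitions                                  # 1000 % 40 = 0
idx = np.arange(time_horizon)
partitions_idx = np.split(idx[:-remainder] if remainder else idx,
                          n_partitions)                                  # 40 chunks of 25 steps
n_keep = math.ceil(n_partitions * train_ratio)                           # ceil(40 * 0.75) = 30
kept = n_keep * partitions_idx[0].shape[0] / time_horizon                # 30 * 25 / 1000 = 0.75
assert kept == train_ratio  # Exactly the requested fraction is retained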
@@ -260,6 +310,8 @@ def get_dynamics_dataset(train_shards: list[Path],
         recordings = [DynamicsRecording.load_from_file(f, obs_names=relevant_obs) for f in partition_shards]
         if partition == "train":
             pred_horizon = train_pred_horizon
+            if train_ratio < 1.0:
+                reduce_dataset_size(recordings, train_ratio)
         elif partition == "val":
             pred_horizon = eval_pred_horizon
         else:
@@ -275,12 +327,8 @@ def get_dynamics_dataset(train_shards: list[Path],
                                              action_obs=tuple(action_obs))
                        )
 
-        # for sample in dataset:
         log.debug(f"[Dataset {partition} - Trajs:{num_trajs} - Samples: {num_samples} - "
                   f"Frames per sample : {frames_per_step}]-----------------------------")
-        # log.debug(f"\tstate: {state.shape} = (frames_per_step, state_dim)")
-        # log.debug(f"\tnext_state: {next_state.shape} = (pred_horizon, frames_per_step, state_dim)")
-        # break
 
         dataset.info.dataset_size = num_samples
         dataset.info.dataset_name = f"[{partition}] Linear dynamics"
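Note that reduce_dataset_size mutates the DynamicsRecording objects in place, which is why the call site above ignores its return value (the input is returned unchanged only in the train_ratio == 1.0 no-op case). A minimal usage sketch, with shard_paths and relevant_obs as hypothetical placeholders:

# Hypothetical usage sketch; shard_paths and relevant_obs are placeholders.
recordings = [DynamicsRecording.load_from_file(f, obs_names=relevant_obs) for f in shard_paths]
reduce_dataset_size(recordings, train_ratio=0.5)  # Modifies each recording in place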

utils/mysc.py

+11
@@ -13,6 +13,17 @@
 
 from torch import Tensor
 
+class TemporaryNumpySeed:  # Temporarily fix NumPy's global seed; restore the prior RNG state on exit
+    def __init__(self, seed):
+        self.seed = seed
+        self.state = None
+
+    def __enter__(self):
+        self.state = np.random.get_state()
+        np.random.seed(self.seed)
+
+    def __exit__(self, *args):
+        np.random.set_state(self.state)
 
 def powerset(iterable):
     "Return the list of all subsets of the input iterable"
