@@ -1,3 +1,4 @@
+import copy
 import logging
 import math
 import pickle

@@ -35,9 +36,13 @@ class DynamicsRecording:
     def save_to_file(self, file_path: Path):
         # Store representations and groups without serializing
         if len(self.obs_representations) > 0:
-            self._obs_rep_irreps = {k: rep.irreps for k, rep in self.obs_representations.items()}
-            self._obs_rep_names = {k: rep.name for k, rep in self.obs_representations.items()}
-            self._obs_rep_Q = {k: rep.change_of_basis for k, rep in self.obs_representations.items()}
+            self._obs_rep_irreps = {}
+            self._obs_rep_names = {}
+            self._obs_rep_Q = {}
+            for k, rep in self.obs_representations.items():
+                self._obs_rep_irreps[k] = rep.irreps if rep is not None else None
+                self._obs_rep_names[k] = rep.name if rep is not None else None
+                self._obs_rep_Q[k] = rep.change_of_basis if rep is not None else None
             group = self.obs_representations[self.state_obs[0]].group
             self._group_keys = group._keys
             self._group_name = group.__class__.__name__
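
The dict comprehensions removed above assumed every observation carries a symmetry representation, so save_to_file raised AttributeError whenever an entry of obs_representations was None; the replacement loop stores None for such entries instead. A minimal illustration of the guard, using a hypothetical Rep stand-in rather than the real representation class:

class Rep:  # hypothetical stand-in for the real representation type
    def __init__(self, name):
        self.name = name

obs_reps = {"q_js": Rep("Q_js"), "time": None}  # 'time' has no symmetry rep

# Old: {k: rep.name for k, rep in obs_reps.items()} -> AttributeError on None.
# New: observations without a representation serialize as None.
names = {k: rep.name if rep is not None else None for k, rep in obs_reps.items()}
assert names == {"q_js": "Q_js", "time": None}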

@@ -335,7 +340,7 @@ def reduce_dataset_size(recordings: Iterable[DynamicsRecording], train_ratio: fl
         return recordings
     log.info(f"Reducing dataset size to {train_ratio * 100}%")
     # Ensure all training seeds use the same training data partitions
-    from utils.mysc import TemporaryNumpySeed
+    from morpho_symm.utils.mysc import TemporaryNumpySeed
     with TemporaryNumpySeed(10):
         for r in recordings:
             # Decide to keep a ratio of the original trajectories

@@ -380,6 +385,54 @@ def reduce_dataset_size(recordings: Iterable[DynamicsRecording], train_ratio: fl
             new_recordings = {k: v[idx_to_keep] for k, v in r.recordings.items()}
             r.recordings = new_recordings
 
+
+def split_train_val_test(
+        dyn_recording: DynamicsRecording, partition_sizes=(0.70, 0.15, 0.15)) -> tuple[DynamicsRecording, ...]:
+    assert np.isclose(np.sum(partition_sizes), 1.0), f"Invalid partition sizes {partition_sizes}"
+    partitions_names = ["train", "val", "test"]
+
+    log.info(f"Partitioning {dyn_recording.description} into train/val/test of sizes {partition_sizes}[%]")
+    # Ensure all training seeds use the same training data partitions
+    from morpho_symm.utils.mysc import TemporaryNumpySeed
+    with TemporaryNumpySeed(10):  # Ensure deterministic behavior
+        # Decide to keep a ratio of the original trajectories
+        num_trajs = int(dyn_recording.info['num_traj'])
+        if num_trajs < 10:  # Do not discard entire trajectories, but rather parts of the trajectories
+            # Take the time horizon from the first observation
+            sample_obs = dyn_recording.recordings[dyn_recording.state_obs[0]]
+            if len(sample_obs.shape) == 3:  # [traj, time, obs_dim]
+                time_horizon = sample_obs.shape[1]
+            elif len(sample_obs.shape) == 2:  # [time, obs_dim]
+                time_horizon = sample_obs.shape[0]
+            else:
+                raise RuntimeError(f"Invalid shape {sample_obs.shape} of {dyn_recording.state_obs[0]}")
+
+            num_samples = time_horizon
+            min_idx = 0
+            partitions_sample_idx = {partition: None for partition in partitions_names}
+            for partition_name, ratio in zip(partitions_names, partition_sizes):
+                max_idx = min_idx + int(num_samples * ratio)
+                partitions_sample_idx[partition_name] = list(range(min_idx, max_idx))
+                min_idx = min_idx + int(num_samples * ratio)
+
+            # TODO: Avoid deep copying the data itself.
+            partitions_recordings = {partition: copy.deepcopy(dyn_recording) for partition in partitions_names}
+            for partition_name, sample_idx in partitions_sample_idx.items():
+                part_num_samples = len(sample_idx)
+                partitions_recordings[partition_name].info['trajectory_length'] = part_num_samples
+                partitions_recordings[partition_name].recordings = dict()
+                for obs_name in dyn_recording.recordings.keys():
+                    if len(dyn_recording.recordings[obs_name].shape) == 3:
+                        data = dyn_recording.recordings[obs_name][:, sample_idx]
+                    elif len(dyn_recording.recordings[obs_name].shape) == 2:
+                        data = dyn_recording.recordings[obs_name][sample_idx]
+                    else:
+                        raise RuntimeError(f"Invalid shape {dyn_recording.recordings[obs_name].shape} of {obs_name}")
+                    partitions_recordings[partition_name].recordings[obs_name] = data
+
+            return partitions_recordings['train'], partitions_recordings['val'], partitions_recordings['test']
+        else:  # Discard entire trajectories
+            raise NotImplementedError()
+
 
 def get_dynamics_dataset(train_shards: list[Path],
                          test_shards: Optional[list[Path]] = None,
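
For orientation, a hypothetical usage sketch of the new split; the loader call and file path are assumptions, not shown in this diff. Since partition boundaries are computed with int truncation, the three contiguous time windows never overlap but may drop a few trailing samples:

from pathlib import Path

# Hypothetical loader counterpart to save_to_file.
recording = DynamicsRecording.load_from_file(Path("recordings/example.pkl"))

train, val, test = split_train_val_test(recording, partition_sizes=(0.70, 0.15, 0.15))
for name, part in zip(("train", "val", "test"), (train, val, test)):
    print(name, part.info['trajectory_length'])  # samples per trajectory in each split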