Skip to content

Commit 5f66aff

Browse files
committed
Equiv DPNets working-Iso basis bug fixed
1 parent 949e0f5 commit 5f66aff

18 files changed

+1014
-669
lines changed

cfg/config.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ hydra:
3737
chdir: True
3838
env_set:
3939
XLA_PYTHON_CLIENT_PREALLOCATE: 'false'
40-
HYDRA_FULL_ERROR: '1'
40+
HYDRA_FULL_ERROR: 1
4141
# CUDA_VISIBLE_DEVICES: ${.device}
4242
config:
4343
override_dirname:
@@ -54,6 +54,8 @@ hydra:
5454
- model.equivariant
5555
- model.max_ck_window_length
5656
- model.activation
57+
- model.num_layers
58+
- model.num_hidden_units
5759
- model.lr
5860
- model.batch_norm
5961
- model.bias

cfg/model/base_model.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ augment: False # Weather to use data-augmentation of the dataset if the
66
# Architectural and Task parameters
77
equivariant: True # Impose equivariance constraints on NNs
88
activation: ReLU
9-
n_layers: 4 # Number MLPs' layers (including input and output layers)
9+
num_layers: 4 # Number MLPs' layers (including input and output layers)
1010
bias: True # Use bias in the MLPs
1111
batch_norm: True # Use batch normalization layers before activation functions in MLPs
1212

cfg/model/dpnet.yaml

+13-8
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,25 @@ defaults:
33

44
name: DPNet
55
# Model hyperparameters
6+
7+
# Symmetry exploitation parameters
68
equivariant: False
7-
activation: Identity
8-
n_layers: 2 # Number MLPs' layers (including input and output layers)
9-
batch_norm: False # Something wrong happens when we turn batch norm on. Performance goes to hell.
9+
group_avg_trick: True
10+
11+
activation: ELU
12+
num_layers: 4 # Number MLPs' layers (including input and output layers)
13+
num_hidden_units: 128 # Number of hidden units in each layer
14+
batch_norm: ${model.equivariant} # Something wrong happens when we turn batch norm on. Performance goes to hell.
1015
bias: False
1116
# Optimization hyperparameters parameters
1217
lr: 1e-3
1318
batch_size: 1024
1419

1520
max_ck_window_length: ${system.pred_horizon} # Maximum length of the Chapman-Kolmogorov window
16-
ck_w: 0.0 # Weight of the Chapman-Kolmogorov regularization term in the loss function
17-
orth_w: 0.1 # Weight of the orthonormal regularization term in the loss function
18-
aux_obs_space: False # Whether to use an auxiliary observable space.
19-
use_spectral_score: True # Whether to use the spectral or the correlation score
21+
ck_w: 0.0 # Weight of the Chapman-Kolmogorov regularization term in the loss function
22+
orth_w: 0.5 # Weight of the orthonormal regularization term in the loss function
23+
aux_obs_space: False # Whether to use an auxiliary observable space.
24+
use_spectral_score: True # Whether to use the spectral or the correlation score
2025

21-
summary: ${model.name}-Equiv:${model.equivariant}-CK_w:${model.ck_w}-Orth_w:${model.orth_w}-Win:${model.max_ck_window_length}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.n_layers}-AOS:${model.aux_obs_space}-SS:${model.use_spectral_score}
26+
summary: ${model.name}-Equiv:${model.equivariant}-CK_w:${model.ck_w}-Orth_w:${model.orth_w}-Win:${model.max_ck_window_length}-Act:${model.activation}-B:${model.bias}-BN:${model.batch_norm}-LR:${model.lr}-L:${model.num_layers}-${model.num_hidden_units}-AOS:${model.aux_obs_space}-SS:${model.use_spectral_score}
2227

cfg/system/linear_system.yaml

+6-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ frames_per_state: 1 # Number of time-frames to use as a Markov Process
1010
pred_horizon: 10 # Number (or percentage) of Markov Process state time steps to predict into the future
1111
eval_pred_horizon: 100 # Number (or percentage) of Markov Process state time steps to predict into the future
1212

13-
data_path: '${system.name}/${system.state_dim}-dim/'
13+
group: C3
14+
noise_level: 0
15+
n_constraints: 0
1416

15-
summary: S:${system.state_dim}-OS:${system.obs_state_dim}-H:${system.pred_horizon}-Hv:${system.eval_pred_horizon}-F:${system.frames_per_state}
17+
data_path: '${system.name}/group=${system.group}-dim=${system.state_dim}/'
18+
19+
summary: S:${system.state_dim}-OS:${system.obs_state_dim}-H:${system.pred_horizon}-G:${system.group}-N:${system.noise_level}

data/DynamicsDataModule.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self,
2525
data_path: Path,
2626
pred_horizon: Union[int, float] = 0.25,
2727
eval_pred_horizon: Union[int, float] = 0.5,
28+
system_cfg: Optional[dict] = None,
2829
batch_size: int = 256,
2930
frames_per_step: int = 1,
3031
num_workers: int = 0,
@@ -34,8 +35,11 @@ def __init__(self,
3435
action_measurements: Optional[list[str]] = None,
3536
):
3637
super().__init__()
38+
if system_cfg is None:
39+
system_cfg = {}
3740
assert data_path.exists(), f"Data folder not found {data_path.absolute()}"
3841
self._data_path = data_path
42+
self.system_cfg = system_cfg
3943
self.augment = augment
4044
self.frames_per_step = frames_per_step
4145
if isinstance(pred_horizon, float):
@@ -74,11 +78,12 @@ def prepare_data(self):
7478
start_time = time.time()
7579
log.info(f"Preparing datasets {self._data_path}")
7680

77-
path_to_dyn_sys_data = set([a.parent for a in list(self._data_path.rglob('*train.pkl'))])
78-
# TODO: Handle multiple files from
79-
system_data_path = path_to_dyn_sys_data.pop()
80-
if len(path_to_dyn_sys_data) > 1:
81-
raise NotImplementedError("Multiple dynamical systems not supported yet")
81+
dyn_sys_data = set([a.parent for a in list(self._data_path.rglob('*train.pkl'))])
82+
noise_level = self.system_cfg.get('noise_level', 0)
83+
system_data_path = [path for path in dyn_sys_data if f"noise_level={noise_level}" in str(path)]
84+
if len(system_data_path) > 1:
85+
raise RuntimeError(f"Multiple potential paths {system_data_path} found")
86+
system_data_path = system_data_path.pop()
8287

8388
train_data, test_data, val_data = get_train_test_val_file_paths(system_data_path)
8489
# Obtain hugging face Iterable datasets instances

0 commit comments

Comments
 (0)