Skip to content

Commit b8cd185

Browse files
committed
Shared encoder DPNet
1 parent f282b20 commit b8cd185

File tree

75 files changed

+78
-206
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

75 files changed

+78
-206
lines changed

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=0/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=1/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=2/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=3/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=4/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=5/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=6/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=7/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=8/test_trajectories.html

-14
This file was deleted.

data/linear_system/group=SO(2)-dim=3/n_constraints=0/f_time_constant=inf[s]-frames=200-horizon=11.7[s]/noise_level=9/test_trajectories.html

-14
This file was deleted.

nn/DeepProjections.py

+15-11
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ def __init__(
5555
orth_w: float = 0.1,
5656
enforce_constant_fn: bool = True,
5757
use_spectral_score: bool = True,
58-
aux_obs_space: bool = False,
58+
shared_encoder: bool = True,
5959
obs_fn_params: Optional[dict] = None,
6060
linear_decoder: bool = True,
6161
**markov_dyn_params
@@ -67,7 +67,7 @@ def __init__(
6767
self.ck_w = ck_w
6868
self.orth_w = orth_w
6969
self.use_spectral_score = use_spectral_score
70-
self.aux_obs_space = aux_obs_space
70+
self.shared_encoder = shared_encoder
7171
self.inverse_projector = None # if linear decoder is true, this is the map between obs to states.
7272
self.inverse_projector_bias = None
7373
self.linear_decoder = linear_decoder
@@ -169,7 +169,7 @@ def compute_loss_and_metrics(self,
169169
return loss, obs_space_metrics
170170

171171
def get_obs_space_metrics(self, obs_state_traj: Tensor, obs_state_traj_aux: Optional[Tensor] = None) -> dict:
172-
if obs_state_traj_aux is None and self.aux_obs_space:
172+
if obs_state_traj_aux is None and self.shared_encoder:
173173
raise ValueError("aux_obs_space is True but obs_state_traj_aux is None")
174174
# Compute Covariance and Cross-Covariance operators for the observation state space.
175175
# Spectral and Projection scores, and CK loss terms.
@@ -270,14 +270,18 @@ def approximate_transfer_operator(self, train_data_loader: DataLoader):
270270
def build_obs_fn(self, num_layers, identity=False, **kwargs):
271271
if identity:
272272
return lambda x: (x, x)
273-
obs_fn = MLP(in_dim=self.state_dim, out_dim=self.obs_state_dim, num_layers=num_layers,
274-
head_with_activation=False, **kwargs)
275-
obs_fn_aux = None
276-
if self.aux_obs_space:
277-
obs_fn_aux = MLP(in_dim=self.state_dim, out_dim=self.obs_state_dim, num_layers=num_layers,
278-
head_with_activation=False, **kwargs)
279-
280-
return ObservableNet(obs_fn=obs_fn, obs_fn_aux=obs_fn_aux)
273+
num_hidden_units = kwargs['num_hidden_units']
274+
# Define the feature extractor used by the observable function.
275+
encoder = MLP(in_dim=self.state_dim,
276+
out_dim=num_hidden_units,
277+
num_layers=num_layers,
278+
head_with_activation=True, **kwargs)
279+
aux_encoder = None
280+
if not self.shared_encoder:
281+
aux_encoder = MLP(in_dim=self.state_dim, out_dim=num_hidden_units, num_layers=num_layers,
282+
head_with_activation=True, **kwargs)
283+
284+
return ObservableNet(encoder=encoder, aux_encoder=aux_encoder, obs_dim=self.obs_state_dim)
281285

282286
def build_inv_obs_fn(self, num_layers, linear_decoder: bool, identity=False, **kwargs):
283287
if identity:

nn/EquivDeepPojections.py

+14-14
Original file line number | Diff line number | Diff line change
@@ -273,20 +273,20 @@ def build_obs_fn(self, num_layers, **kwargs):
273273
in_type=self.state_type_iso,
274274
desired_hidden_units=num_hidden_units)
275275

276-
obs_fn = EMLP(in_type=self.state_type_iso,
277-
out_type=self.obs_state_type,
278-
num_layers=num_layers,
279-
activation=act,
280-
**kwargs)
281-
obs_fn_aux = None
282-
if self.aux_obs_space:
283-
obs_fn_aux = EMLP(in_type=self.state_type_iso,
284-
out_type=self.obs_state_type,
285-
num_layers=num_layers,
286-
activation=act,
287-
**kwargs)
288-
289-
return ObservableNet(obs_fn=obs_fn, obs_fn_aux=obs_fn_aux)
276+
encoder = EMLP(in_type=self.state_type_iso,
277+
out_type=act.out_type,
278+
num_layers=num_layers,
279+
activation=act,
280+
**kwargs)
281+
aux_encoder = None
282+
if not self.shared_encoder:
283+
aux_encoder = EMLP(in_type=self.state_type_iso,
284+
out_type=act.out_type,
285+
num_layers=num_layers,
286+
activation=act,
287+
**kwargs)
288+
289+
return ObservableNet(encoder=encoder, aux_encoder=aux_encoder, obs_type=self.obs_state_type)
290290

291291
def build_inv_obs_fn(self, num_layers, linear_decoder: bool, **kwargs):
292292
if linear_decoder:

nn/LightningLatentMarkovDynamics.py

+14-9
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def training_step(self, batch, batch_idx):
6666

6767
self.log("loss/train", loss, prog_bar=False)
6868
self.log_metrics(scalar_metrics, suffix="train", batch_size=self._batch_size)
69-
# self.log_vector_metrics(vector_metrics, type_sufix="train", batch_size=self._batch_size)
69+
self.log_vector_metrics(vector_metrics, type_sufix="train", batch_size=self._batch_size)
7070
return loss
7171

7272
def validation_step(self, batch, batch_idx):
@@ -76,7 +76,7 @@ def validation_step(self, batch, batch_idx):
7676

7777
self.log("loss/val", loss, prog_bar=False)
7878
self.log_metrics(scalar_metrics, suffix="val", batch_size=self._batch_size)
79-
# self.log_vector_metrics(vector_metrics, type_sufix="val", batch_size=self._batch_size)
79+
self.log_vector_metrics(vector_metrics, type_sufix="val", batch_size=self._batch_size)
8080
return {'output': outputs, 'input': batch}
8181

8282
def test_step(self, batch, batch_idx):
@@ -117,7 +117,7 @@ def on_fit_start(self) -> None:
117117
self._loss_metrics_fn = loss_metrics_fn
118118

119119
def on_train_start(self):
120-
self.log("noise_level", self.trainer.datamodule.noise_level, prog_bar=False, on_epoch=True)
120+
# self.log("noise_level", self.trainer.datamodule.noise_level, prog_bar=False, on_epoch=True)
121121

122122
if hasattr(self.model, "approximate_transfer_operator"):
123123
metrics = self.model.approximate_transfer_operator(self.trainer.datamodule.predict_dataloader())
@@ -147,7 +147,7 @@ def on_train_end(self) -> None:
147147
self.compute_figure_metrics(self.val_metrics_fn, self.trainer.datamodule.train_dataloader(), suffix="train")
148148

149149
def on_validation_start(self) -> None:
150-
if hasattr(self.model, "approximate_transfer_operator"):
150+
if hasattr(self.model, "approximate_transfer_operator") and self.trainer.current_epoch % 2 == 0:
151151
metrics = self.model.approximate_transfer_operator(self.trainer.datamodule.predict_dataloader())
152152
vector_metrics, scalar_metrics = self.separate_vector_scalar_metrics(metrics)
153153
self.log_metrics(scalar_metrics, suffix='')
@@ -192,13 +192,18 @@ def log_vector_metrics(self, metrics: Optional[dict]=None, type_sufix='', batch_
192192

193193
for metric, vector in flat_metrics.items():
194194
assert vector.ndim >= 1, f"Vector metric {metric} has to be of shape (n_samples,) or (batch, time_steps)."
195-
# Separate the last _sufix part from the key to obtain the metric name.
196-
tmp = metric.split('_') # Average value will use this name, vector metric will use the full name.
197-
metric_name, metric_sufix = '_'.join(tmp[:-1]), tmp[-1]
198-
199195
metric_log_name = f"{metric}/{type_sufix}"
200-
# self.log(metric_log_name, torch.mean(vector), prog_bar=False, batch_size=batch_size)
196+
if "_t/" in metric_log_name or "_dist/" in metric_log_name:
197+
# Separate the last _sufix part from the key to obtain the metric name.
198+
tmp = metric.split('_') # Average value will use this name, vector metric will use the full name.
199+
metric_name, metric_sufix = '_'.join(tmp[:-1]), tmp[-1]
200+
self.log(f"{metric_name}/{type_sufix}", torch.mean(vector), prog_bar=False, batch_size=batch_size)
201+
else:
202+
self.log(metric_log_name, torch.mean(vector), prog_bar=False, batch_size=batch_size)
201203

204+
if type_sufix == 'train' or type_sufix == 'val':
205+
continue
206+
202207
if metric_log_name in self._log_cache:
203208
self._log_cache[metric_log_name] = np.concatenate([self._log_cache[metric_log_name], vector.detach().cpu().numpy()], axis=0)
204209
else:

nn/ObservableNet.py

+34-31
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,47 @@
55
import torch.nn
66
from escnn.nn import EquivariantModule
77

8-
from nn.EquivLinearDynamics import EquivLinearDynamics
9-
from nn.LinearDynamics import LinearDynamics
10-
from nn.mlp import MLP
11-
from nn.emlp import EMLP
12-
13-
148
class ObservableNet(torch.nn.Module):
15-
16-
def __init__(self,
17-
obs_fn: Union[torch.nn.Module, EquivariantModule],
18-
obs_fn_aux: Optional[Union[torch.nn.Module, EquivariantModule]] = None):
9+
"""
10+
A network computing the observation state in the initial observable space H with measure μ(t) and the observable
11+
space H' representing the observable space after a Δt step, which potentially has a different measure μ(t+Δt).
12+
"""
13+
def __init__(self,
14+
encoder: Union[torch.nn.Module, EquivariantModule],
15+
aux_encoder: Optional[Union[torch.nn.Module, EquivariantModule]] = None,
16+
obs_dim: Optional[int] = None,
17+
obs_type: Optional[escnn.nn.FieldType] = None):
1918
super().__init__()
20-
self.equivariant = isinstance(obs_fn, EquivariantModule)
21-
self.use_aux_obs_fn = obs_fn_aux is not None
22-
23-
self.obs = obs_fn
24-
self.obs_aux = None
25-
if self.use_aux_obs_fn: # Use two twin networks to compute the main and auxiliary observable space.
26-
self.obs_aux = obs_fn_aux
19+
self.equivariant = isinstance(encoder, EquivariantModule)
20+
self.use_aux_encoder = aux_encoder is not None
21+
22+
self.encoder = encoder
23+
self.aux_encoder = aux_encoder if self.use_aux_encoder else None
24+
25+
# Setting the bias of the linear layer to true is equivalent to setting the constant function in the basis
26+
# of the space of functions. Then the bias of each dimension is the coefficient of the constant function.
27+
if self.equivariant:
28+
# Bias term (a.k.a the constant function) is present only on the trivial isotypic subspace
29+
assert obs_type is not None, f"obs state Field type must be provided when using equivariant encoder"
30+
self.obs_H = escnn.nn.Linear(
31+
in_type=self.encoder.out_type, out_type=obs_type, bias=True)
32+
self.obs_H_prime = escnn.nn.Linear(
33+
in_type=self.encoder.out_type, out_type=obs_type, bias=True)
2734
else:
28-
# Setting the bias of the linear layer to true is equivalent to setting the constant function in the basis
29-
# of the space of functions. Then the bias of each dimension is the coefficient of the constant function.
30-
if self.equivariant:
31-
self.transfer_op_H_H_prime = escnn.nn.Linear(
32-
in_type=self.obs.out_type, out_type=self.obs.out_type, bias=True)
33-
else:
34-
self.transfer_op_H_H_prime = torch.nn.Linear(
35-
in_features=self.obs.out_dim, out_features=self.obs.out_dim, bias=True)
35+
assert obs_dim is not None, f"obs state dimension must be provided when using non-equivariant encoder"
36+
self.obs_H = torch.nn.Linear(
37+
in_features=self.encoder.out_dim, out_features=obs_dim, bias=True)
38+
self.obs_H_prime = torch.nn.Linear(
39+
in_features=self.encoder.out_dim, out_features=obs_dim, bias=True)
3640

3741
def forward(self, input):
3842

39-
obs_state = self.obs(input)
43+
features = self.encoder(input)
44+
aux_features = self.aux_encoder(input) if self.use_aux_encoder else features
4045

41-
if self.use_aux_obs_fn:
42-
obs_aux_state = self.obs_aux(input)
43-
else:
44-
obs_aux_state = self.transfer_op_H_H_prime(obs_state)
46+
obs_state_H = self.obs_H(features)
47+
obs_state_H_prime = self.obs_H_prime(aux_features)
4548

46-
return obs_state, obs_aux_state
49+
return obs_state_H, obs_state_H_prime
4750

4851

nn/emlp.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def get_activation(activation, in_type: FieldType, desired_hidden_units: int):
120120
irreps=list(unique_irreps),
121121
function=f"p_{activation.lower()}",
122122
inplace=True,
123-
type='regular' if not group.continuous else 'rand',
123+
type='rand', #'regular' if not group.continuous else 'rand'#TODO: fix this
124124
N=grid_length)
125125

126126
def forward(self, x):

0 commit comments

Comments
 (0)