Skip to content

Commit

Permalink
Multimodal ResNet (ML4GW#143)
Browse files Browse the repository at this point in the history
* add function to return time and frequency  domain waveforms

* add multimodal resnet

* switch from fft to rfft

* run pre-commit

* add docstring

* add prefix freq for frequency domain waveforms

* update docstring, add args for time and freq separately

* switch to relative import

* add test for multimodal embedding

* pass patience correctly through the scheduler

* remove asds from datamodule.inject output

* hardcode patience=20

* remove patience as a class variable

* remove asds

* add scheduler patience and factor as arguments of the model

* remove time_and_frequency_domain_strain
  • Loading branch information
ravioli1369 authored Oct 24, 2024
1 parent 7014809 commit 4d9319c
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 14 deletions.
1 change: 1 addition & 0 deletions amplfi/train/architectures/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .dense import CoherentDenseEmbedding, NChannelDenseEmbedding
from .flattener import Flattener
from .multimodal import MultiModal
from .resnet import ResNet
70 changes: 70 additions & 0 deletions amplfi/train/architectures/embeddings/multimodal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Literal, Optional

import torch
from ml4gw.nn.norm import NormLayer
from ml4gw.nn.resnet.resnet_1d import ResNet1D

from .base import Embedding


class MultiModal(Embedding):
def __init__(
self,
num_ifos: int,
time_context_dim: int,
freq_context_dim: int,
time_layers: list[int],
freq_layers: list[int],
time_kernel_size: int = 3,
freq_kernel_size: int = 3,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
stride_type: Optional[list[Literal["stride", "dilation"]]] = None,
norm_layer: Optional[NormLayer] = None,
**kwargs
):
"""
MultiModal embedding network that embeds both time and frequency data.
We pass the data through their own ResNets defined by their layers
and context dims, then concatenate the output embeddings.
"""
super().__init__()
self.time_domain_resnet = ResNet1D(
in_channels=num_ifos,
layers=time_layers,
classes=time_context_dim,
kernel_size=time_kernel_size,
zero_init_residual=zero_init_residual,
groups=groups,
width_per_group=width_per_group,
stride_type=stride_type,
norm_layer=norm_layer,
)
self.frequency_domain_resnet = ResNet1D(
in_channels=int(num_ifos * 2),
layers=freq_layers,
classes=freq_context_dim,
kernel_size=freq_kernel_size,
zero_init_residual=zero_init_residual,
groups=groups,
width_per_group=width_per_group,
stride_type=stride_type,
norm_layer=norm_layer,
)

# set the context dimension so
# the flow can access it
self.context_dim = time_context_dim + freq_context_dim

def forward(self, X):
time_domain_embedded = self.time_domain_resnet(X)
X_fft = torch.fft.rfft(X)
X_fft = torch.cat((X_fft.real, X_fft.imag), dim=1)
frequency_domain_embedded = self.frequency_domain_resnet(X_fft)

embedding = torch.concat(
(time_domain_embedded, frequency_domain_embedded), dim=1
)
return embedding
10 changes: 2 additions & 8 deletions amplfi/train/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def on_train_start(self, trainer, pl_module):
X = X.to(device)

cross, plus, parameters = datamodule.waveform_sampler.sample(X)
strain, asds, parameters = datamodule.inject(
X, cross, plus, parameters
)
strain, parameters = datamodule.inject(X, cross, plus, parameters)

# save an example validation batch
# and parameters to disk
Expand All @@ -59,7 +57,7 @@ def on_train_start(self, trainer, pl_module):
val_parameters = {
k: val_parameters[:, i] for i, k in enumerate(keys)
}
val_strain, val_asds, val_parameters = datamodule.inject(
val_strain, val_parameters = datamodule.inject(
background, val_cross, val_plus, val_parameters
)

Expand All @@ -72,30 +70,26 @@ def on_train_start(self, trainer, pl_module):
with h5py.File(f, "w") as h5file:
h5file["strain"] = strain.cpu().numpy()
h5file["parameters"] = parameters.cpu().numpy()
h5file["asds"] = asds.cpu().numpy()
s3_file.write(f.getvalue())

with s3.open(f"{save_dir}/val-batch.h5", "wb") as s3_file:
with io.BytesIO() as f:
with h5py.File(f, "w") as h5file:
h5file["strain"] = val_strain.cpu().numpy()
h5file["parameters"] = val_parameters.cpu().numpy()
h5file["asds"] = val_asds.cpu().numpy()
s3_file.write(f.getvalue())
else:
with h5py.File(
os.path.join(save_dir, "train-batch.h5"), "w"
) as f:
f["strain"] = strain.cpu().numpy()
f["parameters"] = parameters.cpu().numpy()
f["asds"] = asds.cpu().numpy()

with h5py.File(
os.path.join(save_dir, "val-batch.h5"), "w"
) as f:
f["strain"] = val_strain.cpu().numpy()
f["parameters"] = val_parameters.cpu().numpy()
f["asds"] = val_asds.cpu().numpy()


class SaveAugmentedSimilarityBatch(pl.Callback):
Expand Down
3 changes: 2 additions & 1 deletion amplfi/train/configs/flow/cbc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ model:
class_path: ml4gw.nn.norm.GroupNorm1DGetter
init_args:
groups: 8
patience: null
patience: 10
factor: 0.1
save_top_k_models: 10
learning_rate: 3.7e-4
weight_decay: 0.0
Expand Down
3 changes: 2 additions & 1 deletion amplfi/train/configs/flow/sg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ model:
class_path: ml4gw.nn.norm.GroupNorm1DGetter
init_args:
groups: 16
patience: null
patience: 10
factor: 0.1
save_top_k_models: 10
learning_rate: 3.77e-4
weight_decay: 0.0
Expand Down
3 changes: 2 additions & 1 deletion amplfi/train/configs/similarity/cbc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ model:
class_path: ml4gw.nn.norm.GroupNorm1DGetter
init_args:
groups: 16
patience: null
patience: 10
factor: 0.1
save_top_k_models: 10
learning_rate: 3.77e-4
weight_decay: 0.0
Expand Down
11 changes: 9 additions & 2 deletions amplfi/train/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ def __init__(
learning_rate: float,
weight_decay: float = 0.0,
save_top_k_models: int = 10,
patience: Optional[int] = None,
patience: int = 10,
factor: float = 0.1,
checkpoint: Optional[Path] = None,
verbose: bool = False,
):
super().__init__()
self.scheduler_patience = patience
self.scheduler_factor = factor
self._logger = self.init_logging(verbose)
self.outdir = outdir
outdir.mkdir(exist_ok=True, parents=True)
Expand Down Expand Up @@ -106,7 +109,11 @@ def configure_optimizers(self):
weight_decay=self.hparams.weight_decay,
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
patience=self.scheduler_patience,
factor=self.scheduler_factor,
)
return {
"optimizer": optimizer,
"lr_scheduler": {"scheduler": scheduler, "monitor": "valid_loss"},
Expand Down
29 changes: 28 additions & 1 deletion tests/architectures/embeddings/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from amplfi.train.architectures.embeddings import ResNet
from amplfi.train.architectures.embeddings import MultiModal, ResNet
from amplfi.train.architectures.embeddings.dense import DenseEmbedding


Expand Down Expand Up @@ -30,6 +30,16 @@ def out_features(request):
return request.param


@pytest.fixture(params=[32, 64, 128])
def time_out_features(request):
return request.param


@pytest.fixture(params=[32, 64, 128])
def freq_out_features(request):
return request.param


def test_dense_embedding(n_ifos, length):
embedding = DenseEmbedding(length, 10)
x = torch.randn(8, n_ifos, length)
Expand All @@ -43,3 +53,20 @@ def test_resnet(n_ifos, length, out_features, kernel_size):
x = torch.randn(100, n_ifos, length)
y = embedding(x)
assert y.shape == (100, out_features)


def test_multimodal(
n_ifos, length, time_out_features, freq_out_features, kernel_size
):
embedding = MultiModal(
n_ifos,
time_out_features,
freq_out_features,
[3, 3],
[3, 3],
kernel_size,
kernel_size,
)
x = torch.randn(100, n_ifos, length)
y = embedding(x)
assert y.shape == (100, time_out_features + freq_out_features)

0 comments on commit 4d9319c

Please sign in to comment.