From 4d9319c8b0c7ec5a018849ea62ad890f3daff601 Mon Sep 17 00:00:00 2001
From: Ravi Kumar
Date: Thu, 24 Oct 2024 20:40:13 +0530
Subject: [PATCH] Multimodal ResNet (#143)

* add function to return time and frequency domain waveforms
* add multimodal resnet
* switch from fft to rfft
* run pre-commit
* add docstring
* add prefix freq for frequency domain waveforms
* update docstring, add args for time and freq separately
* switch to relative import
* add test for multimodal embedding
* pass patience correctly through the scheduler
* remove asds from datamodule.inject output
* hardcode patience=20
* remove patience as a class variable
* remove asds
* add scheduler patience and factor as arguments of the model
* remove time_and_frequency_domain_strain
---
 .../architectures/embeddings/__init__.py   |  1 +
 .../architectures/embeddings/multimodal.py | 70 +++++++++++++++++++
 amplfi/train/callbacks.py                  | 10 +--
 amplfi/train/configs/flow/cbc.yaml         |  3 +-
 amplfi/train/configs/flow/sg.yaml          |  3 +-
 amplfi/train/configs/similarity/cbc.yaml   |  3 +-
 amplfi/train/models/base.py                | 11 ++-
 .../embeddings/test_embeddings.py          | 29 +++++++-
 8 files changed, 116 insertions(+), 14 deletions(-)
 create mode 100644 amplfi/train/architectures/embeddings/multimodal.py

diff --git a/amplfi/train/architectures/embeddings/__init__.py b/amplfi/train/architectures/embeddings/__init__.py
index 5e5433b0..bfeb9ed8 100644
--- a/amplfi/train/architectures/embeddings/__init__.py
+++ b/amplfi/train/architectures/embeddings/__init__.py
@@ -1,3 +1,4 @@
 from .dense import CoherentDenseEmbedding, NChannelDenseEmbedding
 from .flattener import Flattener
+from .multimodal import MultiModal
 from .resnet import ResNet
diff --git a/amplfi/train/architectures/embeddings/multimodal.py b/amplfi/train/architectures/embeddings/multimodal.py
new file mode 100644
index 00000000..9aedbad9
--- /dev/null
+++ b/amplfi/train/architectures/embeddings/multimodal.py
@@ -0,0 +1,70 @@
+from typing import Literal, Optional
+
+import torch
+from ml4gw.nn.norm import NormLayer
+from ml4gw.nn.resnet.resnet_1d import ResNet1D
+
+from .base import Embedding
+
+
+class MultiModal(Embedding):
+    def __init__(
+        self,
+        num_ifos: int,
+        time_context_dim: int,
+        freq_context_dim: int,
+        time_layers: list[int],
+        freq_layers: list[int],
+        time_kernel_size: int = 3,
+        freq_kernel_size: int = 3,
+        zero_init_residual: bool = False,
+        groups: int = 1,
+        width_per_group: int = 64,
+        stride_type: Optional[list[Literal["stride", "dilation"]]] = None,
+        norm_layer: Optional[NormLayer] = None,
+        **kwargs
+    ):
+        """
+        MultiModal embedding network that embeds both time and frequency data.
+
+        Each domain is passed through its own ResNet, configured by its own
+        layers and context dimension; the output embeddings are concatenated.
+        """
+        super().__init__()
+        self.time_domain_resnet = ResNet1D(
+            in_channels=num_ifos,
+            layers=time_layers,
+            classes=time_context_dim,
+            kernel_size=time_kernel_size,
+            zero_init_residual=zero_init_residual,
+            groups=groups,
+            width_per_group=width_per_group,
+            stride_type=stride_type,
+            norm_layer=norm_layer,
+        )
+        self.frequency_domain_resnet = ResNet1D(
+            in_channels=int(num_ifos * 2),
+            layers=freq_layers,
+            classes=freq_context_dim,
+            kernel_size=freq_kernel_size,
+            zero_init_residual=zero_init_residual,
+            groups=groups,
+            width_per_group=width_per_group,
+            stride_type=stride_type,
+            norm_layer=norm_layer,
+        )
+
+        # set the context dimension so
+        # the flow can access it
+        self.context_dim = time_context_dim + freq_context_dim
+
+    def forward(self, X):
+        time_domain_embedded = self.time_domain_resnet(X)
+        X_fft = torch.fft.rfft(X)
+        X_fft = torch.cat((X_fft.real, X_fft.imag), dim=1)
+        frequency_domain_embedded = self.frequency_domain_resnet(X_fft)
+
+        embedding = torch.concat(
+            (time_domain_embedded, frequency_domain_embedded), dim=1
+        )
+        return embedding
diff --git a/amplfi/train/callbacks.py b/amplfi/train/callbacks.py
index e8525a51..337780bb 100644
--- a/amplfi/train/callbacks.py
+++ b/amplfi/train/callbacks.py
@@ -36,9 +36,7 @@ def on_train_start(self, trainer, pl_module):
         X = X.to(device)
         cross, plus, parameters = datamodule.waveform_sampler.sample(X)
-        strain, asds, parameters = datamodule.inject(
-            X, cross, plus, parameters
-        )
+        strain, parameters = datamodule.inject(X, cross, plus, parameters)

         # save an example validation batch
         # and parameters to disk
@@ -59,7 +57,7 @@ def on_train_start(self, trainer, pl_module):
         val_parameters = {
             k: val_parameters[:, i] for i, k in enumerate(keys)
         }
-        val_strain, val_asds, val_parameters = datamodule.inject(
+        val_strain, val_parameters = datamodule.inject(
             background, val_cross, val_plus, val_parameters
         )
@@ -72,7 +70,6 @@ def on_train_start(self, trainer, pl_module):
                 with h5py.File(f, "w") as h5file:
                     h5file["strain"] = strain.cpu().numpy()
                     h5file["parameters"] = parameters.cpu().numpy()
-                    h5file["asds"] = asds.cpu().numpy()
                 s3_file.write(f.getvalue())

             with s3.open(f"{save_dir}/val-batch.h5", "wb") as s3_file:
@@ -80,7 +77,6 @@ def on_train_start(self, trainer, pl_module):
                 with h5py.File(f, "w") as h5file:
                     h5file["strain"] = val_strain.cpu().numpy()
                     h5file["parameters"] = val_parameters.cpu().numpy()
-                    h5file["asds"] = val_asds.cpu().numpy()
                 s3_file.write(f.getvalue())
         else:
             with h5py.File(
                 os.path.join(save_dir, "train-batch.h5"), "w"
             ) as f:
                 f["strain"] = strain.cpu().numpy()
                 f["parameters"] = parameters.cpu().numpy()
-                f["asds"] = asds.cpu().numpy()

             with h5py.File(
                 os.path.join(save_dir, "val-batch.h5"), "w"
             ) as f:
                 f["strain"] = val_strain.cpu().numpy()
                 f["parameters"] = val_parameters.cpu().numpy()
-                f["asds"] = val_asds.cpu().numpy()


 class SaveAugmentedSimilarityBatch(pl.Callback):
diff --git a/amplfi/train/configs/flow/cbc.yaml b/amplfi/train/configs/flow/cbc.yaml
index bace06a1..ee054951 100644
--- a/amplfi/train/configs/flow/cbc.yaml
+++ b/amplfi/train/configs/flow/cbc.yaml
@@ -41,7 +41,8 @@ model:
         class_path: ml4gw.nn.norm.GroupNorm1DGetter
         init_args:
           groups: 8
-    patience: null
+    patience: 10
+    factor: 0.1
     save_top_k_models: 10
     learning_rate: 3.7e-4
     weight_decay: 0.0
diff --git a/amplfi/train/configs/flow/sg.yaml b/amplfi/train/configs/flow/sg.yaml
index cd00fc3d..0a34b8e1 100644
--- a/amplfi/train/configs/flow/sg.yaml
+++ b/amplfi/train/configs/flow/sg.yaml
@@ -39,7 +39,8 @@ model:
         class_path: ml4gw.nn.norm.GroupNorm1DGetter
         init_args:
           groups: 16
-    patience: null
+    patience: 10
+    factor: 0.1
     save_top_k_models: 10
     learning_rate: 3.77e-4
     weight_decay: 0.0
diff --git a/amplfi/train/configs/similarity/cbc.yaml b/amplfi/train/configs/similarity/cbc.yaml
index 3b9990fd..6f5c881b 100644
--- a/amplfi/train/configs/similarity/cbc.yaml
+++ b/amplfi/train/configs/similarity/cbc.yaml
@@ -30,7 +30,8 @@ model:
         class_path: ml4gw.nn.norm.GroupNorm1DGetter
         init_args:
           groups: 16
-    patience: null
+    patience: 10
+    factor: 0.1
     save_top_k_models: 10
     learning_rate: 3.77e-4
     weight_decay: 0.0
diff --git a/amplfi/train/models/base.py b/amplfi/train/models/base.py
index ed758045..22516b1c 100644
--- a/amplfi/train/models/base.py
+++ b/amplfi/train/models/base.py
@@ -36,11 +36,14 @@ def __init__(
         learning_rate: float,
         weight_decay: float = 0.0,
         save_top_k_models: int = 10,
-        patience: Optional[int] = None,
+        patience: int = 10,
+        factor: float = 0.1,
         checkpoint: Optional[Path] = None,
         verbose: bool = False,
     ):
         super().__init__()
+        self.scheduler_patience = patience
+        self.scheduler_factor = factor
         self._logger = self.init_logging(verbose)
         self.outdir = outdir
         outdir.mkdir(exist_ok=True, parents=True)
@@ -106,7 +109,11 @@ def configure_optimizers(self):
             weight_decay=self.hparams.weight_decay,
         )

-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer,
+            patience=self.scheduler_patience,
+            factor=self.scheduler_factor,
+        )
         return {
             "optimizer": optimizer,
             "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_loss"},
diff --git a/tests/architectures/embeddings/test_embeddings.py b/tests/architectures/embeddings/test_embeddings.py
index 05508289..88ab4fcc 100644
--- a/tests/architectures/embeddings/test_embeddings.py
+++ b/tests/architectures/embeddings/test_embeddings.py
@@ -1,7 +1,7 @@
 import pytest
 import torch

-from amplfi.train.architectures.embeddings import ResNet
+from amplfi.train.architectures.embeddings import MultiModal, ResNet
 from amplfi.train.architectures.embeddings.dense import DenseEmbedding


@@ -30,6 +30,16 @@ def out_features(request):
     return request.param


+@pytest.fixture(params=[32, 64, 128])
+def time_out_features(request):
+    return request.param
+
+
+@pytest.fixture(params=[32, 64, 128])
+def freq_out_features(request):
+    return request.param
+
+
 def test_dense_embedding(n_ifos, length):
     embedding = DenseEmbedding(length, 10)
     x = torch.randn(8, n_ifos, length)
@@ -43,3 +53,20 @@ def test_resnet(n_ifos, length, out_features, kernel_size):
     x = torch.randn(100, n_ifos, length)
     y = embedding(x)
     assert y.shape == (100, out_features)
+
+
+def test_multimodal(
+    n_ifos, length, time_out_features, freq_out_features, kernel_size
+):
+    embedding = MultiModal(
+        n_ifos,
+        time_out_features,
+        freq_out_features,
+        [3, 3],
+        [3, 3],
+        kernel_size,
+        kernel_size,
+    )
+    x = torch.randn(100, n_ifos, length)
+    y = embedding(x)
+    assert y.shape == (100, time_out_features + freq_out_features)
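
Usage sketch (reviewer note, not part of the patch): the snippet below illustrates how the new MultiModal embedding is constructed and what it returns, mirroring the arguments exercised in test_multimodal above. The batch size, num_ifos=2, and 2048 time samples are illustrative assumptions, not values fixed by this change.

    import torch

    from amplfi.train.architectures.embeddings import MultiModal

    # illustrative values only: 2 interferometers, 2048 time-domain samples
    embedding = MultiModal(
        num_ifos=2,
        time_context_dim=64,  # size of the time-domain embedding
        freq_context_dim=64,  # size of the frequency-domain embedding
        time_layers=[3, 3],
        freq_layers=[3, 3],
    )

    x = torch.randn(100, 2, 2048)  # (batch, ifos, time)
    y = embedding(x)

    # forward() embeds the raw strain with the time-domain ResNet, embeds the
    # concatenated real/imaginary parts of its rfft with the frequency-domain
    # ResNet, and joins the two embeddings along the feature axis, so the flow
    # sees a context of size time_context_dim + freq_context_dim.
    assert y.shape == (100, 128)
    assert embedding.context_dim == 128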