From 77fe28d389e31416848a0c2143503c1d5fce75a0 Mon Sep 17 00:00:00 2001 From: Ethan Marx <61295922+EthanMarx@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:42:47 -0400 Subject: [PATCH] decouple context dim from flow (#144) * decouple context dim from flow * fix tests * add correct pull command to docs --- amplfi/train/architectures/embeddings/base.py | 14 +++++++++----- amplfi/train/architectures/embeddings/resnet.py | 6 +++++- amplfi/train/architectures/flows/base.py | 6 +++--- amplfi/train/architectures/flows/coupling.py | 2 +- amplfi/train/architectures/flows/iaf.py | 2 +- amplfi/train/cli/flow.py | 6 ------ amplfi/train/configs/flow/cbc.yaml | 2 +- amplfi/train/configs/flow/sg.yaml | 2 +- docs/containers.md | 4 ++-- tests/architectures/flows/test_flows.py | 3 --- 10 files changed, 23 insertions(+), 24 deletions(-) diff --git a/amplfi/train/architectures/embeddings/base.py b/amplfi/train/architectures/embeddings/base.py index 6cd96904..ccf3094d 100644 --- a/amplfi/train/architectures/embeddings/base.py +++ b/amplfi/train/architectures/embeddings/base.py @@ -5,13 +5,17 @@ class Embedding(torch.nn.Module): """ Dummy base class for embedding networks. - All embeddings should accept as arguments - `num_ifos`, `context_dim` and `strain_dim` as - their first through third arguments, since the - CLI links to this argument from the datamodule - at initialization time. + All embeddings should accept `num_ifos` + as their first argument. They should also + define a `context_dim` attribute that returns + the dimensionality of the output of the network, + which will be used to instantiate the flow transorms. This class obvioulsy isn't necessary, but leaving this as a reminder that we may wan't to enforce the above behavior more explicitly in the future """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.context_dim = None diff --git a/amplfi/train/architectures/embeddings/resnet.py b/amplfi/train/architectures/embeddings/resnet.py index dcc4c4d2..fba10708 100644 --- a/amplfi/train/architectures/embeddings/resnet.py +++ b/amplfi/train/architectures/embeddings/resnet.py @@ -18,8 +18,8 @@ def __init__( width_per_group: int = 64, stride_type: Optional[list[Literal["stride", "dilation"]]] = None, norm_layer: Optional[NormLayer] = None, - **kwargs ): + super().__init__( num_ifos, layers=layers, @@ -31,3 +31,7 @@ def __init__( stride_type=stride_type, norm_layer=norm_layer, ) + + # set the context dimension so + # the flow can access it + self.context_dim = context_dim diff --git a/amplfi/train/architectures/flows/base.py b/amplfi/train/architectures/flows/base.py index 97fbd353..8b99e1d2 100644 --- a/amplfi/train/architectures/flows/base.py +++ b/amplfi/train/architectures/flows/base.py @@ -7,20 +7,20 @@ from pyro.distributions.conditional import ConditionalComposeTransformModule from pyro.nn import PyroModule +from amplfi.train.architectures.embeddings.base import Embedding + class FlowArchitecture(PyroModule): def __init__( self, num_params: int, - context_dim: int, - embedding_net: torch.nn.Module, + embedding_net: Embedding, embedding_weights: Optional[Path] = None, freeze_embedding: bool = False, ): super().__init__() self.num_params = num_params - self.context_dim = context_dim self.embedding_net = embedding_net if freeze_embedding: diff --git a/amplfi/train/architectures/flows/coupling.py b/amplfi/train/architectures/flows/coupling.py index 523aae9a..f10aba69 100644 --- a/amplfi/train/architectures/flows/coupling.py +++ b/amplfi/train/architectures/flows/coupling.py @@ -35,7 +35,7 @@ def transform_block(self): """Returns single affine coupling transform""" arn = ConditionalDenseNN( self.split_dim, - self.context_dim, + self.embedding_net.context_dim, [self.hidden_features], param_dims=[ self.num_params - self.split_dim, diff --git a/amplfi/train/architectures/flows/iaf.py b/amplfi/train/architectures/flows/iaf.py index 699c4642..653c2c86 100644 --- a/amplfi/train/architectures/flows/iaf.py +++ b/amplfi/train/architectures/flows/iaf.py @@ -36,7 +36,7 @@ def transform_block(self): """Returns single autoregressive transform""" arn = ConditionalAutoRegressiveNN( self.num_params, - self.context_dim, + self.embedding_net.context_dim, self.num_blocks * [self.hidden_features], nonlinearity=self.activation, ) diff --git a/amplfi/train/cli/flow.py b/amplfi/train/cli/flow.py index 1abf33af..ea701700 100644 --- a/amplfi/train/cli/flow.py +++ b/amplfi/train/cli/flow.py @@ -13,12 +13,6 @@ def add_arguments_to_parser(self, parser): apply_on="parse", ) - parser.link_arguments( - "model.init_args.arch.init_args.context_dim", - "model.init_args.arch.init_args.embedding_net.init_args.context_dim", # noqa - apply_on="parse", - ) - parser.link_arguments( "data.init_args.ifos", "model.init_args.arch.init_args.embedding_net.init_args.num_ifos", diff --git a/amplfi/train/configs/flow/cbc.yaml b/amplfi/train/configs/flow/cbc.yaml index 42811d24..bace06a1 100644 --- a/amplfi/train/configs/flow/cbc.yaml +++ b/amplfi/train/configs/flow/cbc.yaml @@ -28,7 +28,6 @@ model: hidden_features: 150 num_transforms: 80 num_blocks: 6 - context_dim: 7 # uncomment below to load # in pre-trained embedding weights # embedding_weights: "path/to/embedding/weights" @@ -36,6 +35,7 @@ model: embedding_net: class_path: amplfi.train.architectures.embeddings.ResNet init_args: + context_dim: 7 layers: [5, 3, 3] norm_layer: class_path: ml4gw.nn.norm.GroupNorm1DGetter diff --git a/amplfi/train/configs/flow/sg.yaml b/amplfi/train/configs/flow/sg.yaml index 53a415b9..cd00fc3d 100644 --- a/amplfi/train/configs/flow/sg.yaml +++ b/amplfi/train/configs/flow/sg.yaml @@ -26,7 +26,6 @@ model: hidden_features: 150 num_transforms: 80 num_blocks: 6 - context_dim: 8 # uncomment below to load # in pre-trained embedding weights # embedding_weights: "path/to/embedding/weights" @@ -34,6 +33,7 @@ model: embedding_net: class_path: amplfi.train.architectures.embeddings.ResNet init_args: + context_dim: 8 layers: [5, 3, 3] norm_layer: class_path: ml4gw.nn.norm.GroupNorm1DGetter diff --git a/docs/containers.md b/docs/containers.md index c6ff3083..20afd42b 100644 --- a/docs/containers.md +++ b/docs/containers.md @@ -11,7 +11,7 @@ You can pull the container locally with either docker or apptainer .. code-block:: console - $ apptainer pull docker://ghcr.io/ml4gw/amplfi/data:main $AMPLFI_CONTAINER_ROOT/amplfi.sif + $ apptainer pull $AFRAME_CONTAINER_ROOT/amplfi.sif docker://ghcr.io/ml4gw/amplfi/amplfi:main Supported python versions: 3.9-3.12. @@ -19,7 +19,7 @@ You can pull the container locally with either docker or apptainer .. code-block:: console - $ docker pull ghcr.io/ml4gw/amplfi:main $AMPLFI_CONTAINER_ROOT/amplfi.sif + $ docker pull ghcr.io/ml4gw/amplfi/amplfi:main Supported python versions: 3.9-3.12. ``` diff --git a/tests/architectures/flows/test_flows.py b/tests/architectures/flows/test_flows.py index b561368d..397cda28 100644 --- a/tests/architectures/flows/test_flows.py +++ b/tests/architectures/flows/test_flows.py @@ -44,7 +44,6 @@ def test_coupling_flow( coupling_flow = CouplingFlow( param_dim, - context_dim, embedding, num_transforms=num_transforms, ) @@ -63,7 +62,6 @@ def test_autoregressive_flow( iaf = InverseAutoregressiveFlow( param_dim, - context_dim, embedding, num_transforms=num_transforms, ) @@ -72,7 +70,6 @@ def test_autoregressive_flow( maf = MaskedAutoregressiveFlow( param_dim, - context_dim, embedding, num_transforms=num_transforms, )