Skip to content

Commit

Permalink
decouple context dim from flow (ML4GW#144)
Browse files Browse the repository at this point in the history
* decouple context dim from flow

* fix tests

* add correct pull command to docs
  • Loading branch information
EthanMarx authored Oct 10, 2024
1 parent f2f2cae commit 77fe28d
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 24 deletions.
14 changes: 9 additions & 5 deletions amplfi/train/architectures/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ class Embedding(torch.nn.Module):
"""
Dummy base class for embedding networks.
All embeddings should accept as arguments
`num_ifos`, `context_dim` and `strain_dim` as
their first through third arguments, since the
CLI links to this argument from the datamodule
at initialization time.
All embeddings should accept `num_ifos`
as their first argument. They should also
define a `context_dim` attribute that returns
the dimensionality of the output of the network,
which will be used to instantiate the flow transorms.
This class obvioulsy isn't necessary, but leaving this
as a reminder that we may wan't to enforce
the above behavior more explicitly in the future
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.context_dim = None
6 changes: 5 additions & 1 deletion amplfi/train/architectures/embeddings/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def __init__(
width_per_group: int = 64,
stride_type: Optional[list[Literal["stride", "dilation"]]] = None,
norm_layer: Optional[NormLayer] = None,
**kwargs
):

super().__init__(
num_ifos,
layers=layers,
Expand All @@ -31,3 +31,7 @@ def __init__(
stride_type=stride_type,
norm_layer=norm_layer,
)

# set the context dimension so
# the flow can access it
self.context_dim = context_dim
6 changes: 3 additions & 3 deletions amplfi/train/architectures/flows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
from pyro.distributions.conditional import ConditionalComposeTransformModule
from pyro.nn import PyroModule

from amplfi.train.architectures.embeddings.base import Embedding


class FlowArchitecture(PyroModule):
def __init__(
self,
num_params: int,
context_dim: int,
embedding_net: torch.nn.Module,
embedding_net: Embedding,
embedding_weights: Optional[Path] = None,
freeze_embedding: bool = False,
):

super().__init__()
self.num_params = num_params
self.context_dim = context_dim
self.embedding_net = embedding_net

if freeze_embedding:
Expand Down
2 changes: 1 addition & 1 deletion amplfi/train/architectures/flows/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def transform_block(self):
"""Returns single affine coupling transform"""
arn = ConditionalDenseNN(
self.split_dim,
self.context_dim,
self.embedding_net.context_dim,
[self.hidden_features],
param_dims=[
self.num_params - self.split_dim,
Expand Down
2 changes: 1 addition & 1 deletion amplfi/train/architectures/flows/iaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def transform_block(self):
"""Returns single autoregressive transform"""
arn = ConditionalAutoRegressiveNN(
self.num_params,
self.context_dim,
self.embedding_net.context_dim,
self.num_blocks * [self.hidden_features],
nonlinearity=self.activation,
)
Expand Down
6 changes: 0 additions & 6 deletions amplfi/train/cli/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ def add_arguments_to_parser(self, parser):
apply_on="parse",
)

parser.link_arguments(
"model.init_args.arch.init_args.context_dim",
"model.init_args.arch.init_args.embedding_net.init_args.context_dim", # noqa
apply_on="parse",
)

parser.link_arguments(
"data.init_args.ifos",
"model.init_args.arch.init_args.embedding_net.init_args.num_ifos",
Expand Down
2 changes: 1 addition & 1 deletion amplfi/train/configs/flow/cbc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ model:
hidden_features: 150
num_transforms: 80
num_blocks: 6
context_dim: 7
# uncomment below to load
# in pre-trained embedding weights
# embedding_weights: "path/to/embedding/weights"
# freeze_embedding: false
embedding_net:
class_path: amplfi.train.architectures.embeddings.ResNet
init_args:
context_dim: 7
layers: [5, 3, 3]
norm_layer:
class_path: ml4gw.nn.norm.GroupNorm1DGetter
Expand Down
2 changes: 1 addition & 1 deletion amplfi/train/configs/flow/sg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ model:
hidden_features: 150
num_transforms: 80
num_blocks: 6
context_dim: 8
# uncomment below to load
# in pre-trained embedding weights
# embedding_weights: "path/to/embedding/weights"
# freeze_embedding: false
embedding_net:
class_path: amplfi.train.architectures.embeddings.ResNet
init_args:
context_dim: 8
layers: [5, 3, 3]
norm_layer:
class_path: ml4gw.nn.norm.GroupNorm1DGetter
Expand Down
4 changes: 2 additions & 2 deletions docs/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ You can pull the container locally with either docker or apptainer
.. code-block:: console
$ apptainer pull docker://ghcr.io/ml4gw/amplfi/data:main $AMPLFI_CONTAINER_ROOT/amplfi.sif
$ apptainer pull $AFRAME_CONTAINER_ROOT/amplfi.sif docker://ghcr.io/ml4gw/amplfi/amplfi:main
Supported python versions: 3.9-3.12.
.. tab:: docker
.. code-block:: console
$ docker pull ghcr.io/ml4gw/amplfi:main $AMPLFI_CONTAINER_ROOT/amplfi.sif
$ docker pull ghcr.io/ml4gw/amplfi/amplfi:main
Supported python versions: 3.9-3.12.
```
Expand Down
3 changes: 0 additions & 3 deletions tests/architectures/flows/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def test_coupling_flow(

coupling_flow = CouplingFlow(
param_dim,
context_dim,
embedding,
num_transforms=num_transforms,
)
Expand All @@ -63,7 +62,6 @@ def test_autoregressive_flow(

iaf = InverseAutoregressiveFlow(
param_dim,
context_dim,
embedding,
num_transforms=num_transforms,
)
Expand All @@ -72,7 +70,6 @@ def test_autoregressive_flow(

maf = MaskedAutoregressiveFlow(
param_dim,
context_dim,
embedding,
num_transforms=num_transforms,
)
Expand Down

0 comments on commit 77fe28d

Please sign in to comment.