Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New unifying structure for coupling architectures #139

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions FrEIA/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .invertible import Invertible
23 changes: 23 additions & 0 deletions FrEIA/core/invertible.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

from abc import ABC
import torch.nn as nn

from typing import Any

from typing import TypeVar

T = TypeVar("T")


class Invertible(ABC, nn.Module):
def forward(self, *args: T, **kwargs: T) -> Any:
raise NotImplementedError

def inverse(self, *args, **kwargs):
raise NotImplementedError

def __call__(self, *args, rev = False, **kwargs):
if not rev:
return self.forward(*args, **kwargs)

return self.inverse(*args, **kwargs)
Empty file added FrEIA/flows/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions FrEIA/flows/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

from freia.core import Invertible

class Flow(Invertible):
def __init__(self, transform, distribution):
self.transform = transform
self.distribution = distribution

def forward(self, x):
z, logdet = self.transform.forward(x)

logp = self.distribution.log_prob(z)

nll = -(logp + logdet)

return z, nll

def sample_transform(self, size, temperature):
z = self.distribution.sample(size, temperature)

x, _ = self.transform.inverse(z)

return x


class RecurrentFlow(Flow):
def forward(self, x):
z = x
logdet = None
for t in range(...):
z, logdet = self.transform.forward(z, t)
2 changes: 2 additions & 0 deletions FrEIA/splits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .even import EvenSplit
18 changes: 18 additions & 0 deletions FrEIA/splits/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

from FrEIA.core import Invertible

from typing import Tuple

import torch


class Split(Invertible):
def __init__(self, dim: int = 1):
super().__init__()
self.dim = dim

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
14 changes: 14 additions & 0 deletions FrEIA/splits/even.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

from .base import Split

from typing import Tuple

import torch


class EvenSplit(Split):
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.split(x, 2, dim=1)

def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return torch.cat((x1, x2), dim=1)
3 changes: 3 additions & 0 deletions FrEIA/splits/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@


# class RandomSplit()
2 changes: 2 additions & 0 deletions FrEIA/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .base import Transform
25 changes: 25 additions & 0 deletions FrEIA/transforms/affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

from .base import Transform

import torch

from .coupling import CouplingTransform


class AffineTransform(CouplingTransform):

def __init__(self):
parameter_counts = {...}
super().__init__(parameter_counts=parameter_counts)

def transform_parameters(self, **parameters):
parameters["a"] = torch.exp(parameters["a"])

def _forward(self, x: torch.Tensor, **parameters) -> torch.Tensor:
parameters = self.get_parameters()
a, b = parameters["a"], parameters["b"]
return a * x + b, torch.log(a)

def _inverse(self, z: torch.Tensor, **parameters) -> torch.Tensor:
a, b = parameters["a"], parameters["b"]
return (z - b) / a, -torch.log(a)
46 changes: 46 additions & 0 deletions FrEIA/transforms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

from freia.core import Invertible

import torch


WithJacobian = tuple[torch.Tensor, torch.Tensor]



class Transform(Invertible):
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

def inverse(self, z: torch.Tensor) -> torch.Tensor:
raise NotImplementedError




@Parameterized(scale=1, shift=1)
class AffineTransform(Transform):
def forward(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
return scale * x + shift

def inverse(self, z: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
return (z - shift) / scale


# class SplineTransform(Transform):
# def forward(self, x: torch.Tensor, edges: torch.Tensor):
# assert edges.shape == (..., self.bins)
# pass


class Parameterized:
def __init__(self, **parameter_counts):
self.parameter_counts = parameter_counts

def __call__(self, cls):

cls.forward = forward
cls.inverse = inverse



73 changes: 73 additions & 0 deletions FrEIA/transforms/coupling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

from .base import Transform
from freia.splits import EvenSplit

import torch
import torch.nn as nn


class Spline(Transform):
def __init__(self, affine, inner_spline):
...

def forward(self, x: torch.Tensor, *, condition: torch.Tensor, **kwargs) -> WithJacobian:
x[out] = affine(x[out])
x[out] = inner_spline(x[out])


class Spline(CouplingTransform):
def _forward(self):
x[in] = self._spline(...)
x[out] = self._affine(...)



class CouplingTransform(Transform):
def __init__(self, transform1, transform2, subnet_constructor, split=EvenSplit(dim=1)):
self.split = split
self.subnet1 = subnet_constructor(...)
self.subnet2 = subnet_constructor(...)

def split_parameters(self, parameters: torch.Tensor) -> dict:
pc = self.parameter_counts
parameters = torch.split(parameters, list(pc.values()), dim=1)

return dict(zip(pc.keys(), parameters))

def transform_parameters(self, parameters: dict[torch.Tensor]) -> None:
pass

def get_parameters(self, *args, **kwargs) -> dict:
raise NotImplementedError

def get_parameters(self, u: torch.Tensor, subnet: nn.Module) -> dict:

parameters = subnet(u)
parameters = self.split_parameters(parameters)
should_be_none = self.transform_parameters(**parameters)
if should_be_none is not None:
warnings.warn(...)

return parameters


def forward(self, x: torch.Tensor, **parameters: torch.Tensor) -> torch.Tensor:
x1, x2 = self.split.forward(x)



parameters = self.get_parameters(u=x2, subnet=self.subnet1)
z1, logdet1 = self.transform1.forward(x1, **parameters)
parameters = self.get_parameters(u=z1, subnet=self.subnet2)
z2, logdet2 = self.transform2(x2, **parameters)

z = self.split.inverse(z1, z2)
logdet = logdet1 + logdet2

return z, logdet





my_single_coupling = CouplingTransform(transform1=AffineTransform(...), transform2=None)
12 changes: 12 additions & 0 deletions FrEIA/transforms/identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

from .base import Transform

import torch


class IdentityTransform(Transform):
def forward(self, x: torch.Tensor, **parameters: torch.Tensor) -> WithJacobian:
return x, 0

def inverse(self, z: torch.Tensor, **parameters: torch.Tensor) -> WithJacobian:
return z, 0
59 changes: 59 additions & 0 deletions FrEIA/transforms/ode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@

from .base import Transform

import torch

from scipy.ode import solve_ode



class Parameterized(nn.Module):
def __init__(self, *, subnet_constructor, parameter_counts):
super().__init__()
self.subnet = ...
self.parameter_counts = ...
self.transform = transform_cls

def __call__(self, *args, **kwargs):
self.transform = transform_cls(*args, **kwargs)

return self

def forward(self):
parameters = self.subnet(...)
return self.transform(x, parameters)


@Parameterized
class ODETransform(Transform):
def __init__(self, integration_steps: int = 10):
super().__init__()
self.integration_steps = integration_steps

def forward(self, x: torch.Tensor, **parameters) -> tuple[torch.Tensor, torch.Tensor]:
return euler(x, v, dt)

# ode integration
dt = 1 / self.integration_steps
for _ in range(self.integration_steps):
parameters = self.get_parameters()
v = parameters["v"]
x = euler(x, v, dt)

return x

ODETransform = Parameterized(ODETransform)





ode = ODETransform()




def euler(x, v, dt):
return x + v * dt


Loading