Merge pull request #1 from rs-station/modules
Add Distribution Modules and Transformed Parameters
kmdalton authored Apr 16, 2024
2 parents 715eadc + 541803a commit f0cc884
Showing 8 changed files with 304 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/rs_distributions/distributions/folded_normal.py
@@ -32,6 +32,7 @@ class FoldedNormal(dist.Distribution):
"""

arg_constraints = {"loc": dist.constraints.real, "scale": dist.constraints.positive}
support = torch.distributions.constraints.nonnegative

def __init__(self, loc, scale, validate_args=None):
self.loc = torch.as_tensor(loc)
@@ -50,6 +51,8 @@ def log_prob(self, value):
Returns:
Tensor: The log-probabilities of the given values
"""
if self._validate_args:
self._validate_sample(value)
loc = self.loc
scale = self.scale
log_prob = torch.logaddexp(
@@ -109,6 +112,8 @@ def cdf(self, value):
Returns:
Tensor: The CDF values at the given values
"""
if self._validate_args:
self._validate_sample(value)
value = torch.as_tensor(value, dtype=self.loc.dtype, device=self.loc.device)
# return dist.Normal(loc, scale).cdf(value) - dist.Normal(-loc, scale).cdf(-value)
return 0.5 * (
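
The new `support` constraint and the `_validate_sample` calls mean out-of-support values are rejected when argument validation is enabled. A minimal sketch of that behavior, assuming `FoldedNormal` is importable from `rs_distributions.distributions` and using the constructor shown above (the chosen values and the printed message are illustrative only):

```python
import torch
from rs_distributions.distributions import FoldedNormal

q = FoldedNormal(1.0, 2.0, validate_args=True)
q.log_prob(torch.tensor(0.5))       # inside the nonnegative support: fine

try:
    q.log_prob(torch.tensor(-0.5))  # outside the support: _validate_sample raises
except ValueError as err:
    print("rejected:", err)
```
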
12 changes: 12 additions & 0 deletions src/rs_distributions/modules/__init__.py
@@ -0,0 +1,12 @@
from .transformed_parameter import TransformedParameter # noqa
from .distribution import DistributionModule
from .distribution import * # noqa
from .kl import kl_divergence # noqa
from .distribution import __all__ as all_distributions

__all__ = [
"TransformedParameter",
"DistributionModule",
"kl_divergence",
]
__all__.extend(all_distributions)
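
Because of the star import and the extended `__all__`, the generated distribution modules are re-exported alongside the helpers. A minimal sketch of the resulting surface, assuming the generated classes keep their `torch.distributions` names (as the tests below rely on):

```python
from rs_distributions.modules import (
    DistributionModule,
    TransformedParameter,
    kl_divergence,
    Normal,        # generated subclass wrapping torch.distributions.Normal
    FoldedNormal,  # generated subclass wrapping rs_distributions' FoldedNormal
)

q = Normal(0.0, 1.0)
assert isinstance(q, DistributionModule)
```
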
136 changes: 136 additions & 0 deletions src/rs_distributions/modules/distribution.py
@@ -0,0 +1,136 @@
import torch
from rs_distributions.modules import TransformedParameter
from rs_distributions import distributions as rsd
from inspect import signature
from functools import wraps


# TODO: decide whether to use "ignore" or "include" pattern here
# Distributions which are currently not supported
ignored_distributions = (
"Uniform", # has weird "dependent" constraints
"Binomial", # has_rsample == False
"MixtureSameFamily", # has_rsample == False
)


class DistributionModule(torch.nn.Module):
"""
Base class for constructing learnable distributions.
This subclass of `torch.nn.Module` acts like a `torch.distributions.Distribution`
object with learnable `torch.nn.Parameter` attributes.
It works by lazily constructing distributions as needed.
Here is a simple example of distribution matching using learnable distributions with reparameterized gradients.
```python
from rs_distributions import modules as rsm
import torch
q = rsm.FoldedNormal(10., 5.)
p = torch.distributions.HalfNormal(1.)
opt = torch.optim.Adam(q.parameters())
steps = 10_000
num_samples = 256
for i in range(steps):
opt.zero_grad()
z = q.rsample((num_samples,))
kl = (q.log_prob(z) - p.log_prob(z)).mean()
kl.backward()
opt.step()
```
"""

def __init__(self, distribution_class, *args, **kwargs):
super().__init__()
self.distribution_class = distribution_class
sig = signature(distribution_class)
bargs = sig.bind(*args, **kwargs)
bargs.apply_defaults()
for arg in distribution_class.arg_constraints:
param = bargs.arguments.pop(arg)
param = self._constrain_arg_if_needed(arg, param)
setattr(self, f"_transformed_{arg}", param)
self._extra_args = bargs.arguments

def __repr__(self):
rstring = super().__repr__().split("\n")[1:]
rstring = [str(self.distribution_class) + " DistributionModule("] + rstring
return "\n".join(rstring)

def _distribution(self):
kwargs = {
k: self._realize_parameter(getattr(self, f"_transformed_{k}"))
for k in self.distribution_class.arg_constraints
}
kwargs.update(self._extra_args)
return self.distribution_class(**kwargs)

def _constrain_arg_if_needed(self, name, value):
if isinstance(value, TransformedParameter):
return value
cons = self.distribution_class.arg_constraints[name]
if cons == torch.distributions.constraints.dependent:
transform = torch.distributions.AffineTransform(0.0, 1.0)
else:
transform = torch.distributions.constraint_registry.transform_to(cons)
return TransformedParameter(value, transform)

@staticmethod
def _realize_parameter(param):
if isinstance(param, TransformedParameter):
return param()
return param

def __getattr__(self, name: str):
if name in self.distribution_class.arg_constraints or hasattr(
self.distribution_class, name
):
q = self._distribution()
return getattr(q, name)
return super().__getattr__(name)

@classmethod
def generate_subclass(cls, distribution_class):
class DistributionModuleSubclass(cls):
__doc__ = distribution_class.__doc__
arg_constraints = distribution_class.arg_constraints

@wraps(distribution_class.__init__)
def __init__(self, *args, **kwargs):
super().__init__(distribution_class, *args, **kwargs)

return DistributionModuleSubclass

@staticmethod
def _extract_distributions(*modules, base_class=torch.distributions.Distribution):
"""
Extract all torch.distributions.Distribution subclasses from one or more
modules into a dict {name: cls}.
"""
d = {}
for module in modules:
for k in module.__all__:
distribution_class = getattr(module, k)
if not hasattr(distribution_class, "arg_constraints"):
continue
if not hasattr(distribution_class.arg_constraints, "items"):
continue
if issubclass(distribution_class, base_class):
d[k] = distribution_class
return d


distributions_to_transform = DistributionModule._extract_distributions(
torch.distributions,
rsd,
)

for k in ignored_distributions:
del distributions_to_transform[k]

__all__ = ["DistributionModule"]
for k, v in distributions_to_transform.items():
globals()[k] = DistributionModule.generate_subclass(v)
__all__.append(k)
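
Putting the pieces together: each generated subclass stores its constrained arguments as `TransformedParameter` submodules and rebuilds the underlying distribution on attribute access. A rough sketch of that flow (the shapes and printed values are illustrative):

```python
import torch
from rs_distributions import modules as rsm

q = rsm.Normal(loc=0.0, scale=1.0)  # generated DistributionModule subclass

# constrained args live as TransformedParameter submodules holding unconstrained tensors
print(type(q._transformed_scale).__name__)      # TransformedParameter
print(sum(p.numel() for p in q.parameters()))   # 2 scalar parameters

# attribute access falls through __getattr__ to a freshly constructed distribution
z = q.rsample((4,))
print(q.log_prob(z).shape)                      # torch.Size([4])
```
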
12 changes: 12 additions & 0 deletions src/rs_distributions/modules/kl.py
@@ -0,0 +1,12 @@
from functools import wraps
from rs_distributions import modules as rsm
import torch


@wraps(torch.distributions.kl.kl_divergence)
def kl_divergence(p, q):
if isinstance(p, rsm.DistributionModule):
p = p._distribution()
if isinstance(q, rsm.DistributionModule):
q = q._distribution()
return torch.distributions.kl.kl_divergence(p, q)
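
The wrapper simply unwraps any `DistributionModule` arguments before deferring to torch's KL registry, so gradients still reach the module's parameters. A short hedged example (the target scale is chosen arbitrarily):

```python
import torch
from rs_distributions import modules as rsm

q = rsm.Normal(0.0, 1.0)                  # learnable
p = torch.distributions.Normal(0.0, 2.0)  # fixed target

kl = rsm.kl_divergence(q, p)  # analytic KL from torch's registry
kl.backward()                 # gradients flow into q's parameters
```
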
27 changes: 27 additions & 0 deletions src/rs_distributions/modules/transformed_parameter.py
@@ -0,0 +1,27 @@
import torch


class TransformedParameter(torch.nn.Module):
"""
A `torch.nn.Module` subclass representing a constrained, learnable variable.
"""

def __init__(self, value, transform):
"""
Args:
value : Tensor
The initial value of this learnable parameter
transform : torch.distributions.Transform
A transform instance which is applied to the underlying, unconstrained value
"""
super().__init__()
value = torch.as_tensor(value) # support floats
if isinstance(value, torch.nn.Parameter):  # reuse the Parameter; its data is replaced with the unconstrained value
self._value = value
value.data = transform.inv(value)
else:
self._value = torch.nn.Parameter(transform.inv(value))
self.transform = transform

def forward(self):
return self.transform(self._value)
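
A short sketch of how a `TransformedParameter` keeps a constrained value learnable: the stored parameter lives in unconstrained space, and the forward pass maps it back through the transform. The positive-scale example below is illustrative only:

```python
import torch
from rs_distributions.modules import TransformedParameter

positive = torch.distributions.constraints.positive
transform = torch.distributions.constraint_registry.transform_to(positive)

scale = TransformedParameter(2.0, transform)
print(scale())                   # ~2.0, recovered via the forward transform
print(list(scale.parameters()))  # one unconstrained Parameter (the inverse image of 2.0)
```
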
75 changes: 75 additions & 0 deletions tests/modules/test_distribution_modules.py
@@ -0,0 +1,75 @@
import pytest
from rs_distributions import modules as rsm
import torch

distribution_classes = rsm.DistributionModule._extract_distributions(
rsm, base_class=rsm.DistributionModule
)

# Some distribution classes accept equivalent, mutually exclusive
# parameterizations; keep only the first argument of each group below.
exclusive_args = [
("logits", "probs"),
("covariance_matrix", "precision_matrix", "scale_tril"),
]

# Workarounds for distributions with additional, non-parameter arguments
special_kwargs = {
"RelaxedBernoulli": {"temperature": torch.ones(())},
"RelaxedOneHotCategorical": {"temperature": torch.ones(())},
"TransformedDistribution": {
"base_distribution": rsm.Normal(0.0, 1.0),
"transforms": torch.distributions.AffineTransform(0.0, 1.0),
},
"LKJCholesky": {"dim": 3},
"Independent": {
"base_distribution": rsm.Normal(torch.zeros(3), torch.ones(3)),
"reinterpreted_batch_ndims": 1,
},
}


@pytest.mark.parametrize("distribution_class_name", distribution_classes.keys())
def test_distribution_module(distribution_class_name):
distribution_class = distribution_classes[distribution_class_name]
shape = (3, 3)
kwargs = {}
cons = distribution_class.arg_constraints
for group in exclusive_args:
matches_group = all([g in cons for g in group])
if matches_group:
for con in group[1:]:
del cons[con]
for k, v in cons.items():
try:
t = torch.distributions.constraint_registry.transform_to(v)
kwargs[k] = t(torch.ones(shape))
except NotImplementedError:
# no transform is registered for this constraint; fall back to an
# unconstrained parameter wrapped in an identity-like transform
t = torch.distributions.AffineTransform(0.0, 1.0)
kwargs[k] = rsm.TransformedParameter(torch.ones(shape), t)

if distribution_class_name in special_kwargs:
kwargs.update(special_kwargs[distribution_class_name])
q = distribution_class(**kwargs)

# Not all distributions have these attributes implemented
try:
q.mean
q.variance
q.stddev
except NotImplementedError:
pass

if q.has_rsample:
z = q.rsample()
else:
z = q.sample()

ll = q.log_prob(z)

params = list(q.parameters())
if q.has_rsample:
loss = -ll.sum()
loss.backward()
for x in params:
assert torch.isfinite(x.grad).all()
12 changes: 12 additions & 0 deletions tests/modules/test_kl.py
@@ -0,0 +1,12 @@
from rs_distributions import modules as rsm
import torch


def test_kl_divergence():
q = rsm.Normal(0.0, 1.0)
p = torch.distributions.Normal(0.0, 1.0)

assert all([param.grad is None for param in q.parameters()])
kl = rsm.kl_divergence(q, p)
kl.backward()
assert all([torch.isfinite(param.grad) for param in q.parameters()])
25 changes: 25 additions & 0 deletions tests/modules/test_transformed_parameter.py
@@ -0,0 +1,25 @@
import pytest
from rs_distributions.modules import TransformedParameter
import torch


@pytest.mark.parametrize("shape", [(), 10])
def test_transformed_parameter(shape):
value = 10.0
eps = 1e-6
transform = torch.distributions.ComposeTransform(
[
torch.distributions.AffineTransform(eps, 1.0),
torch.distributions.ExpTransform(),
]
)
variable = TransformedParameter(value, transform)
assert variable() == value

params = list(variable.parameters())
assert len(params) == 1

loss = variable().square().sum()
loss.backward()
assert params[0].grad.isfinite().all()
assert (params[0] != 0).all()
