Merge pull request #1 from rs-station/modules
Add Distribution Modules and Transformed Parameters
Showing 8 changed files with 304 additions and 0 deletions.
@@ -0,0 +1,12 @@
from .transformed_parameter import TransformedParameter  # noqa
from .distribution import DistributionModule
from .distribution import *  # noqa
from .kl import kl_divergence  # noqa
from .distribution import __all__ as all_distributions

__all__ = [
    "TransformedParameter",
    "DistributionModule",
    "kl_divergence",
]
__all__.extend(all_distributions)
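For illustration, a minimal sketch (not part of this diff) of the import surface this `__init__` exposes; it assumes `Normal` ends up among the auto-generated subclasses, which the tests further down also rely on:

```python
from rs_distributions import modules as rsm

# names re-exported above are available at the package level
assert "DistributionModule" in rsm.__all__
assert "kl_divergence" in rsm.__all__

# the auto-generated DistributionModule subclasses are re-exported as well
q = rsm.Normal(0.0, 1.0)
```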
@@ -0,0 +1,136 @@
import torch
from rs_distributions.modules import TransformedParameter
from rs_distributions import distributions as rsd
from inspect import signature
from functools import wraps


# TODO: decide whether to use "ignore" or "include" pattern here
# Distributions which are currently not supported
ignored_distributions = (
    "Uniform",  # has weird "dependent" constraints
    "Binomial",  # has_rsample == False
    "MixtureSameFamily",  # has_rsample == False
)


class DistributionModule(torch.nn.Module):
    """
    Base class for constructing learnable distributions.
    This subclass of `torch.nn.Module` acts like a `torch.distributions.Distribution`
    object with learnable `torch.nn.Parameter` attributes.
    It works by lazily constructing distributions as needed.

    Here is a simple example of distribution matching using learnable distributions with reparameterized gradients.
    ```python
    from rs_distributions import modules as rsm
    import torch

    q = rsm.FoldedNormal(10., 5.)
    p = torch.distributions.HalfNormal(1.)

    opt = torch.optim.Adam(q.parameters())

    steps = 10_000
    num_samples = 256
    for i in range(steps):
        opt.zero_grad()
        z = q.rsample((num_samples,))
        kl = (q.log_prob(z) - p.log_prob(z)).mean()
        kl.backward()
        opt.step()
    ```
    """

    def __init__(self, distribution_class, *args, **kwargs):
        super().__init__()
        self.distribution_class = distribution_class
        sig = signature(distribution_class)
        bargs = sig.bind(*args, **kwargs)
        bargs.apply_defaults()
        for arg in distribution_class.arg_constraints:
            param = bargs.arguments.pop(arg)
            param = self._constrain_arg_if_needed(arg, param)
            setattr(self, f"_transformed_{arg}", param)
        self._extra_args = bargs.arguments

    def __repr__(self):
        rstring = super().__repr__().split("\n")[1:]
        rstring = [str(self.distribution_class) + " DistributionModule("] + rstring
        return "\n".join(rstring)

    def _distribution(self):
        kwargs = {
            k: self._realize_parameter(getattr(self, f"_transformed_{k}"))
            for k in self.distribution_class.arg_constraints
        }
        kwargs.update(self._extra_args)
        return self.distribution_class(**kwargs)

    def _constrain_arg_if_needed(self, name, value):
        if isinstance(value, TransformedParameter):
            return value
        cons = self.distribution_class.arg_constraints[name]
        if cons == torch.distributions.constraints.dependent:
            transform = torch.distributions.AffineTransform(0.0, 1.0)
        else:
            transform = torch.distributions.constraint_registry.transform_to(cons)
        return TransformedParameter(value, transform)

    @staticmethod
    def _realize_parameter(param):
        if isinstance(param, TransformedParameter):
            return param()
        return param

    def __getattr__(self, name: str):
        if name in self.distribution_class.arg_constraints or hasattr(
            self.distribution_class, name
        ):
            q = self._distribution()
            return getattr(q, name)
        return super().__getattr__(name)

    @classmethod
    def generate_subclass(cls, distribution_class):
        class DistributionModuleSubclass(cls):
            __doc__ = distribution_class.__doc__
            arg_constraints = distribution_class.arg_constraints

            @wraps(distribution_class.__init__)
            def __init__(self, *args, **kwargs):
                super().__init__(distribution_class, *args, **kwargs)

        return DistributionModuleSubclass

    @staticmethod
    def _extract_distributions(*modules, base_class=torch.distributions.Distribution):
        """
        Extract all torch.distributions.Distribution subclasses from the given
        module(s) into a dict {name: cls}.
        """
        d = {}
        for module in modules:
            for k in module.__all__:
                distribution_class = getattr(module, k)
                if not hasattr(distribution_class, "arg_constraints"):
                    continue
                if not hasattr(distribution_class.arg_constraints, "items"):
                    continue
                if issubclass(distribution_class, base_class):
                    d[k] = distribution_class
        return d


distributions_to_transform = DistributionModule._extract_distributions(
    torch.distributions,
    rsd,
)

for k in ignored_distributions:
    del distributions_to_transform[k]

__all__ = ["DistributionModule"]
for k, v in distributions_to_transform.items():
    globals()[k] = DistributionModule.generate_subclass(v)
    __all__.append(k)
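As a usage illustration (not part of this diff), here is a hedged sketch of fitting one of the generated modules by maximum likelihood. It assumes `rsm.Normal` is among the auto-generated subclasses (as the tests below do) and relies only on the lazy `_distribution()` / `__getattr__` behaviour defined above:

```python
import torch
from rs_distributions import modules as rsm

data = 3.0 + 2.0 * torch.randn(1000)  # toy observations
q = rsm.Normal(0.0, 1.0)  # loc and scale become TransformedParameters

opt = torch.optim.Adam(q.parameters(), lr=1e-2)
for _ in range(1_000):
    opt.zero_grad()
    # log_prob is resolved through __getattr__, which lazily builds the
    # underlying torch.distributions.Normal from the constrained parameters
    nll = -q.log_prob(data).mean()
    nll.backward()
    opt.step()

print(q.loc, q.scale)  # constrained values realized from the learnable parameters
```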
@@ -0,0 +1,12 @@
from functools import wraps
from rs_distributions import modules as rsm
import torch


@wraps(torch.distributions.kl.kl_divergence)
def kl_divergence(p, q):
    if isinstance(p, rsm.DistributionModule):
        p = p._distribution()
    if isinstance(q, rsm.DistributionModule):
        q = q._distribution()
    return torch.distributions.kl.kl_divergence(p, q)
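A short hedged sketch (not part of this diff) of what this wrapper enables: the closed-form KL from `torch.distributions` can be minimized directly on a `DistributionModule`, instead of the Monte Carlo estimate used in the `DistributionModule` docstring example:

```python
import torch
from rs_distributions import modules as rsm

q = rsm.Normal(10.0, 5.0)                  # learnable
p = torch.distributions.Normal(0.0, 1.0)   # fixed target

opt = torch.optim.Adam(q.parameters())
for _ in range(100):
    opt.zero_grad()
    kl = rsm.kl_divergence(q, p)  # unwraps q via _distribution() before dispatch
    kl.backward()
    opt.step()
```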
@@ -0,0 +1,27 @@
import torch


class TransformedParameter(torch.nn.Module):
    """
    A `torch.nn.Module` subclass representing a constrained variable.
    """

    def __init__(self, value, transform):
        """
        Args:
            value : Tensor
                The initial value of this learnable parameter
            transform : torch.distributions.Transform
                A transform instance which is applied to the underlying, unconstrained value
        """
        super().__init__()
        value = torch.as_tensor(value)  # support floats
        if isinstance(value, torch.nn.Parameter):
            self._value = value
            value.data = transform.inv(value)
        else:
            self._value = torch.nn.Parameter(transform.inv(value))
        self.transform = transform

    def forward(self):
        return self.transform(self._value)
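To illustrate the class on its own, a minimal sketch (not part of this diff): the learnable tensor is stored in unconstrained space and `forward()` maps it back through the transform. Because `_constrain_arg_if_needed` passes pre-built `TransformedParameter`s through unchanged, one can also be handed directly to a generated distribution module (the `rsm.Normal` usage is assumed, as in the tests below):

```python
import torch
from rs_distributions.modules import TransformedParameter
from rs_distributions import modules as rsm

# a strictly positive learnable scale: stored as log(2.0), returned as exp(...)
scale = TransformedParameter(2.0, torch.distributions.ExpTransform())
print(scale())                   # ~2.0
print(list(scale.parameters()))  # a single unconstrained Parameter

# inject the custom parameterization into a generated distribution module
q = rsm.Normal(loc=0.0, scale=scale)
```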
@@ -0,0 +1,75 @@
import pytest
from rs_distributions import modules as rsm
import torch

distribution_classes = rsm.DistributionModule._extract_distributions(
    rsm, base_class=rsm.DistributionModule
)

# It is common to have arguments which are equivalent and mutually exclusive
# in distribution classes.
exlusive_args = [
    ("logits", "probs"),
    ("covariance_matrix", "precision_matrix", "scale_tril"),
]

# Workarounds for distributions with additional, non-parameter arguments
special_kwargs = {
    "RelaxedBernoulli": {"temperature": torch.ones(())},
    "RelaxedOneHotCategorical": {"temperature": torch.ones(())},
    "TransformedDistribution": {
        "base_distribution": rsm.Normal(0.0, 1.0),
        "transforms": torch.distributions.AffineTransform(0.0, 1.0),
    },
    "LKJCholesky": {"dim": 3},
    "Independent": {
        "base_distribution": rsm.Normal(torch.zeros(3), torch.ones(3)),
        "reinterpreted_batch_ndims": 1,
    },
}


@pytest.mark.parametrize("distribution_class_name", distribution_classes.keys())
def test_distribution_module(distribution_class_name):
    distribution_class = distribution_classes[distribution_class_name]
    shape = (3, 3)
    kwargs = {}
    cons = distribution_class.arg_constraints
    for group in exlusive_args:
        matches_group = all([g in cons for g in group])
        if matches_group:
            for con in group[1:]:
                del cons[con]
    for k, v in cons.items():
        try:
            t = torch.distributions.constraint_registry.transform_to(v)
            kwargs[k] = t(torch.ones(shape))
        except NotImplementedError:
            t = torch.distributions.AffineTransform(0.0, 1.0)
            kwargs[k] = rsm.TransformedParameter(v, t)

    if distribution_class_name in special_kwargs:
        kwargs.update(special_kwargs[distribution_class_name])
    q = distribution_class(**kwargs)

    # Not all distributions have these attributes implemented
    try:
        q.mean
        q.variance
        q.stddev
    except NotImplementedError:
        pass

    if q.has_rsample:
        z = q.rsample()
    else:
        z = q.sample()

    ll = q.log_prob(z)

    params = list(q.parameters())
    if q.has_rsample:
        loss = -ll.sum()
        loss.backward()
        for x in params:
            assert torch.isfinite(x.grad).all()
@@ -0,0 +1,12 @@
from rs_distributions import modules as rsm
import torch


def test_kl_divergence():
    q = rsm.Normal(0.0, 1.0)
    p = torch.distributions.Normal(0.0, 1.0)

    assert all([param.grad is None for param in q.parameters()])
    kl = rsm.kl_divergence(q, p)
    kl.backward()
    assert all([torch.isfinite(param.grad) for param in q.parameters()])
@@ -0,0 +1,25 @@
import pytest
from rs_distributions.modules import TransformedParameter
import torch


@pytest.mark.parametrize("shape", [(), 10])
def test_transformed_parameter(shape):
    value = 10.0
    eps = 1e-6
    transform = torch.distributions.ComposeTransform(
        [
            torch.distributions.AffineTransform(eps, 1.0),
            torch.distributions.ExpTransform(),
        ]
    )
    variable = TransformedParameter(value, transform)
    assert variable() == value

    params = list(variable.parameters())
    assert len(params) == 1

    loss = variable().square().sum()
    loss.backward()
    assert params[0].grad.isfinite().all()
    assert (params[0] != 0).all()