Skip to content

Commit

Permalink
Added ability for priors of transformed distributions to have their p… (
Browse files Browse the repository at this point in the history
  • Loading branch information
hvarfner authored Jul 25, 2024
1 parent c118306 commit 917603c
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 4 deletions.
25 changes: 25 additions & 0 deletions gpytorch/priors/prior.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
#!/usr/bin/env python3

from abc import ABC
from typing import Any, Mapping

from torch.distributions import TransformedDistribution
from torch.nn import Module

from ..distributions import Distribution
from .utils import _load_transformed_to_base_dist


TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \
'_transformed' attributes modified, these are just copies of the base attribute. \
Please modify the base attribute (e.g. {}) instead."""


class Prior(Distribution, Module, ABC):
Expand All @@ -25,3 +33,20 @@ def log_prob(self, x):
:rtype: torch.Tensor
"""
return super(Prior, self).log_prob(self.transform(x))

def load_state_dict(self, state_dict: Mapping[str, Any], *args, **kwargs):
Module.load_state_dict(self, state_dict, *args, **kwargs)
if isinstance(self, TransformedDistribution):
_load_transformed_to_base_dist(self)

def __setattr__(self, name: str, value: Any) -> None:
if hasattr(self, name) and "_transformed_" in name:
base_attr_name = name.replace("_transformed_", "")
raise AttributeError(TRANSFORMED_ERROR_MSG.format(base_attr_name))

elif hasattr(self, f"_transformed_{name}"):
self.base_dist.__setattr__(name, value)
super().__setattr__(f"_transformed_{name}", value)

else:
return super().__setattr__(name, value)
3 changes: 3 additions & 0 deletions gpytorch/priors/torch_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class HalfNormalPrior(Prior, HalfNormal):
def __init__(self, scale, validate_args=None, transform=None):
TModule.__init__(self)
HalfNormal.__init__(self, scale=scale, validate_args=validate_args)
_bufferize_attributes(self, ("scale",))
self._transform = transform

def expand(self, batch_shape):
Expand All @@ -54,6 +55,7 @@ class LogNormalPrior(Prior, LogNormal):
def __init__(self, loc, scale, validate_args=None, transform=None):
TModule.__init__(self)
LogNormal.__init__(self, loc=loc, scale=scale, validate_args=validate_args)
_bufferize_attributes(self, ("loc", "scale"))
self._transform = transform

def expand(self, batch_shape):
Expand Down Expand Up @@ -84,6 +86,7 @@ class HalfCauchyPrior(Prior, HalfCauchy):
def __init__(self, scale, validate_args=None, transform=None):
TModule.__init__(self)
HalfCauchy.__init__(self, scale=scale, validate_args=validate_args)
_bufferize_attributes(self, ("scale",))
self._transform = transform

def expand(self, batch_shape):
Expand Down
33 changes: 29 additions & 4 deletions gpytorch/priors/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,36 @@
#!/usr/bin/env python3

from torch.distributions import TransformedDistribution


def _bufferize_attributes(module, attributes):
attr_clones = {attr: getattr(module, attr).clone() for attr in attributes}
for attr, value in attr_clones.items():
delattr(module, attr)
module.register_buffer(attr, value)
r"""
Adds the parameters of the prior as a torch buffer to enable saving/
loading to/from state_dicts.
For TransformedDistributions Adds a _transformed_ attribute to the
parameters. This enables its parameters to be saved and
loaded to/from state_dicts, as the original parameters cannot be.
"""
if isinstance(module, TransformedDistribution):
for attr in attributes:
module.register_buffer(f"_transformed_{attr}", getattr(module, attr))
else:
attr_clones = {attr: getattr(module, attr).clone() for attr in attributes}
for attr, value in attr_clones.items():
delattr(module, attr)
module.register_buffer(attr, value)


def _load_transformed_to_base_dist(module):
r"""loads the _transformed_ attributes to the parameters of a torch
TransformedDistribution. This enables its parameters to be saved and
loaded to/from state_dicts, as the original parameters cannot be.
"""
transf_str = "_transformed_"
transformed_attrs = [attr for attr in dir(module) if transf_str in attr]
for transf_attr in transformed_attrs:
base_attr_name = transf_attr.replace(transf_str, "")
setattr(module.base_dist, base_attr_name, getattr(module, transf_attr))


def _del_attributes(module, attributes, raise_on_error=False):
Expand Down
70 changes: 70 additions & 0 deletions test/priors/test_prior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

import unittest

from torch import Tensor

from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior


TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \
'_transformed' attributes modified, these are just copies of the base attribute. \
Please modify the base attribute (e.g. {}) instead."""


class TestPrior(unittest.TestCase):
def test_state_dict(self):
normal = NormalPrior(0.1, 1).state_dict()
self.assertTrue("loc" in normal)
self.assertTrue("scale" in normal)
self.assertEqual(normal["loc"], 0.1)

gamma = GammaPrior(1.1, 2).state_dict()
self.assertTrue("concentration" in gamma)
self.assertTrue("rate" in gamma)
self.assertEqual(gamma["concentration"], 1.1)

ln = LogNormalPrior(2.1, 1.2).state_dict()
self.assertTrue("_transformed_loc" in ln)
self.assertTrue("_transformed_scale" in ln)
self.assertEqual(ln["_transformed_loc"], 2.1)

hc = HalfCauchyPrior(1.3).state_dict()
self.assertTrue("_transformed_scale" in hc)

def test_load_state_dict(self):
ln1 = LogNormalPrior(loc=0.5, scale=0.1)
ln2 = LogNormalPrior(loc=2.5, scale=2.1)
gm1 = GammaPrior(concentration=0.5, rate=0.1)
gm2 = GammaPrior(concentration=2.5, rate=2.1)
hc1 = HalfCauchyPrior(scale=1.1)
hc2 = HalfCauchyPrior(scale=101.1)

ln2.load_state_dict(ln1.state_dict())
self.assertEqual(ln2.loc, ln1.loc)
self.assertEqual(ln2.scale, ln1.scale)

gm2.load_state_dict(gm1.state_dict())
self.assertEqual(gm2.concentration, gm1.concentration)
self.assertEqual(gm2.rate, gm1.rate)

hc2.load_state_dict(hc1.state_dict())
self.assertEqual(hc2.scale, hc1.scale)

def test_transformed_attributes(self):
norm = NormalPrior(loc=2.5, scale=2.1)
ln = LogNormalPrior(loc=2.5, scale=2.1)
hc = HalfCauchyPrior(scale=2.2)

with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"):
getattr(norm, "_transformed_loc")

self.assertTrue(getattr(ln, "_transformed_loc"), 2.5)
norm.loc = Tensor([1.01])
ln.loc = Tensor([1.01])
self.assertEqual(ln._transformed_loc, 1.01)
with self.assertRaises(AttributeError):
ln._transformed_loc = 1.1

with self.assertRaises(AttributeError):
hc._transformed_scale = 1.01
61 changes: 61 additions & 0 deletions test/priors/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

import unittest

from torch import Tensor

from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior


class TestPrior(unittest.TestCase):
def test_state_dict(self):
normal = NormalPrior(0.1, 1).state_dict()
self.assertTrue("loc" in normal)
self.assertTrue("scale" in normal)
self.assertEqual(normal["loc"], 0.1)

gamma = GammaPrior(1.1, 2).state_dict()
self.assertTrue("concentration" in gamma)
self.assertTrue("rate" in gamma)
self.assertEqual(gamma["concentration"], 1.1)

ln = LogNormalPrior(2.1, 1.2).state_dict()
self.assertTrue("_transformed_loc" in ln)
self.assertTrue("_transformed_scale" in ln)
self.assertEqual(ln["_transformed_loc"], 2.1)

hc = HalfCauchyPrior(1.3).state_dict()
self.assertTrue("_transformed_scale" in hc)

def test_load_state_dict(self):
ln1 = LogNormalPrior(loc=0.5, scale=0.1)
ln2 = LogNormalPrior(loc=2.5, scale=2.1)
gm1 = GammaPrior(concentration=0.5, rate=0.1)
gm2 = GammaPrior(concentration=2.5, rate=2.1)
hc1 = HalfCauchyPrior(scale=1.1)
hc2 = HalfCauchyPrior(scale=101.1)

ln2.load_state_dict(ln1.state_dict())
self.assertEqual(ln2.loc, ln1.loc)
self.assertEqual(ln2.scale, ln1.scale)

gm2.load_state_dict(gm1.state_dict())
self.assertEqual(gm2.concentration, gm1.concentration)
self.assertEqual(gm2.rate, gm1.rate)

hc2.load_state_dict(hc1.state_dict())
self.assertEqual(hc2.scale, hc1.scale)

def test_transformed_attributes(self):
norm = NormalPrior(loc=2.5, scale=2.1)
ln = LogNormalPrior(loc=2.5, scale=2.1)
hc = HalfCauchyPrior(scale=2.2)

with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"):
getattr(norm, "_transformed_loc")

self.assertTrue(getattr(ln, "_transformed_loc"), 2.5)
norm.loc = Tensor([1.01])
ln.loc = Tensor([1.01])
self.assertEqual(ln._transformed_loc, 1.01)
self.assertEqual(hc._transformed_scale, 2.2)

0 comments on commit 917603c

Please sign in to comment.