diff --git a/edward/__init__.py b/edward/__init__.py index ce3305795..833a4300a 100644 --- a/edward/__init__.py +++ b/edward/__init__.py @@ -20,7 +20,7 @@ from edward.util import check_data, check_latent_vars, copy, dot, \ get_ancestors, get_blanket, get_children, get_control_variate_coef, \ get_descendants, get_parents, get_session, get_siblings, get_variables, \ - Progbar, random_variables, rbf, set_seed, to_simplex, transform + marginal, Progbar, random_variables, rbf, set_seed, to_simplex, transform from edward.version import __version__, VERSION from tensorflow.python.util.all_util import remove_undocumented @@ -74,6 +74,7 @@ 'get_session', 'get_siblings', 'get_variables', + 'marginal', 'Progbar', 'random_variables', 'rbf', diff --git a/edward/util/__init__.py b/edward/util/__init__.py index dce454aed..7320d77df 100644 --- a/edward/util/__init__.py +++ b/edward/util/__init__.py @@ -25,6 +25,7 @@ 'get_session', 'get_siblings', 'get_variables', + 'marginal', 'Progbar', 'random_variables', 'rbf', diff --git a/edward/util/random_variables.py b/edward/util/random_variables.py index 5b5f3d137..febb0f41f 100644 --- a/edward/util/random_variables.py +++ b/edward/util/random_variables.py @@ -716,6 +716,69 @@ def get_variables(x, collection=None): return list(output) +def marginal(x, n): + """Performs a full graph sample on the provided random variable. + + Given a random variable and a sample size, adds an additional sample + dimension to the root random variables in x's graph, and samples from + a new graph in terms of that sample size. + + Args: + x : RandomVariable. + Random variable to perform full graph sample on. + n : tf.Tensor or int + The size of the full graph sample to take. + + Returns: + tf.Tensor. + Full graph sample of shape [n] + x.batch_shape + x.event_shape. + + #### Examples + + ```python + ed.get_session() + loc = Normal(0.0, 100.0) + y = Normal(loc, 0.0001) + conditional_sample = y.sample(50) + marginal_sample = ed.marginal(y, 50) + + np.std(conditional_sample.eval()) + 0.000100221 + + np.std(marginal_sample.eval()) + 106.55982 + ``` + + #### Notes + + The current implementation only works for graphs of RVs that don't use + the `sample_shape` kwarg. + """ + ancestors = get_ancestors(x) + if any([rv.sample_shape != () for rv in ancestors]) or x.sample_shape != (): + raise NotImplementedError("`marginal` doesn't support graphs of RVs " + "with non scalar sample_shape args.") + elif ancestors == []: + old_roots = [x] + else: + old_roots = [rv for rv in ancestors if get_ancestors(rv) == []] + + new_roots = [] + for rv in old_roots: + new_rv = copy(rv) + new_rv._sample_shape = tf.TensorShape(n).concatenate(new_rv._sample_shape) + new_rv._value = new_rv.sample(new_rv._sample_shape) + new_roots.append(new_rv) + dict_swap = dict(zip(old_roots, new_roots)) + x_full = copy(x, dict_swap, replace_itself=True) + if x_full.shape[1:] != x.shape: + print(x_full.shape) + print(x.shape) + raise ValueError('Could not transform graph for bulk sampling.') + + return x_full + + def transform(x, *args, **kwargs): """Transform a continuous random variable to the unconstrained space. diff --git a/tests/util/test_marginal.py b/tests/util/test_marginal.py new file mode 100644 index 000000000..d35360692 --- /dev/null +++ b/tests/util/test_marginal.py @@ -0,0 +1,103 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import edward as ed +import numpy as np +import tensorflow as tf + +from edward.models import Normal, InverseGamma + + +class test_marginal_class(tf.test.TestCase): + + def test_bad_graph(self): + with self.test_session(): + loc = Normal(tf.zeros(5), 5.0) + y_loc = tf.expand_dims(loc, 1) # this displaces the sample dimension + inv_scale = Normal(tf.zeros(3), 1.0) + y_scale = tf.expand_dims(tf.nn.softplus(inv_scale), 0) + y = Normal(y_loc, y_scale) + with self.assertRaises(ValueError): + ed.marginal(y, 20) + + def test_sample_arg(self): + with self.test_session(): + y = Normal(0.0, 1.0, sample_shape=10) + with self.assertRaises(NotImplementedError): + ed.marginal(y, 20) + + def test_sample_arg_ancestor(self): + with self.test_session(): + x = Normal(0.0, 1.0, sample_shape=10) + y = Normal(x, 0.0) + with self.assertRaises(NotImplementedError): + ed.marginal(y, 20) + + def test_no_ancestor(self): + with self.test_session(): + y = Normal(0.0, 1.0) + sample = ed.marginal(y, 4) + self.assertEqual(sample.shape, [4]) + + def test_no_ancestor_batch(self): + with self.test_session(): + y = Normal(tf.zeros([2, 3, 4]), 1.0) + sample = ed.marginal(y, 5) + self.assertEqual(sample.shape, [5, 2, 3, 4]) + + def test_single_ancestor(self): + with self.test_session(): + loc = Normal(0.0, 1.0) + y = Normal(loc, 1.0) + sample = ed.marginal(y, 4) + self.assertEqual(sample.shape, [4]) + + def test_single_ancestor_batch(self): + with self.test_session(): + loc = Normal(tf.zeros([2, 3, 4]), 1.0) + y = Normal(loc, 1.0) + sample = ed.marginal(y, 5) + self.assertEqual(sample.shape, [5, 2, 3, 4]) + + def test_sample_passthrough(self): + with self.test_session(): + tf.set_random_seed(1) + loc = Normal(0.0, 100.0) + y = Normal(loc, 0.0001) + conditional_sample = y.sample(50) + marginal_sample = ed.marginal(y, 50) + self.assertTrue(np.std(conditional_sample.eval()) < 1.0) + self.assertTrue(np.std(marginal_sample.eval()) > 1.0) + + def test_multiple_ancestors(self): + with self.test_session(): + loc = Normal(0.0, 1.0) + scale = InverseGamma(1.0, 1.0) + y = Normal(loc, scale) + sample = ed.marginal(y, 4) + self.assertEqual(sample.shape, [4]) + + def test_multiple_ancestors_batch(self): + with self.test_session(): + loc = Normal(tf.zeros(5), 1.0) + scale = InverseGamma(tf.ones(5), 1.0) + y = Normal(loc, scale) + sample = ed.marginal(y, 4) + self.assertEqual(sample.shape, [4, 5]) + + def test_multiple_ancestors_batch_broadcast(self): + with self.test_session(): + loc = Normal(tf.zeros([5, 1]), 1.0) + scale = InverseGamma(tf.ones([1, 6]), 1.0) + y = Normal(loc, scale) + sample = ed.marginal(y, 4) + self.assertEqual(sample.shape, [4, 5, 6]) + + def test_multiple_ancestors_failed_broadcast(self): + with self.test_session(): + loc = Normal(tf.zeros([5, 1]), 1.0) + scale = InverseGamma(tf.ones([6]), 1.0) + y = Normal(loc, scale) + with self.assertRaises(ValueError): + sample = ed.marginal(y, 4)