Marginal Util Function #778
Great work so far.
I do wonder if we should implement this at a slightly higher level, as what you're returning is not the original random variable but its marginal distribution.
If x_full represents the marginal distribution of x, it's no longer x's original distribution. So maybe this warrants defining a Marginal rv class that takes x as input, and where the currently implemented function would be _sample_n. It would be called with a default setting of n (and cache its output to a class member) when the user calls methods such as _log_prob.
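To make the suggestion concrete, here is a minimal sketch of the caching idea in plain Python rather than Edward/TensorFlow. The toy `Normal` class, the `Marginal` wrapper, its `seed` argument, and the `mean`/`variance` methods are all hypothetical stand-ins (the latter two play the role of methods such as `_log_prob` that would trigger the cached draw); none of this is the PR's actual code.

```python
import random


class Normal:
    """Toy normal RV whose mean may itself be a random variable."""

    def __init__(self, loc, scale):
        self.loc = loc      # a float, or another Normal acting as a parent RV
        self.scale = scale  # a float

    def sample(self, rng):
        # Ancestral sampling: draw the parent first if loc is an RV.
        loc = self.loc.sample(rng) if isinstance(self.loc, Normal) else self.loc
        return rng.gauss(loc, self.scale)


class Marginal:
    """Hypothetical wrapper representing the marginal distribution of x."""

    def __init__(self, x, n=1000, seed=0):
        self.x = x
        self.n = n                       # default number of cached draws
        self._rng = random.Random(seed)
        self._cache = None               # filled by the first method that needs draws

    def _sample_n(self, n):
        # Every draw resamples all ancestors, so draws follow the marginal of x.
        return [self.x.sample(self._rng) for _ in range(n)]

    def _cached_samples(self):
        if self._cache is None:
            self._cache = self._sample_n(self.n)
        return self._cache

    def mean(self):
        samples = self._cached_samples()
        return sum(samples) / len(samples)

    def variance(self):
        samples = self._cached_samples()
        m = self.mean()
        return sum((s - m) ** 2 for s in samples) / len(samples)


loc = Normal(0.0, 1.0)
x = Normal(loc, 1.0)
x_full = Marginal(x, n=20000)
# Marginally, x has variance 1 + 1 = 2, so the estimate should land near 2.
print(x_full.mean(), x_full.variance())
```

The point of the cache is that repeated calls to summary methods reuse one set of draws instead of resampling the whole parent graph each time.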
edward/util/random_variables.py
Outdated
@@ -778,3 +778,66 @@ def transform(x, *args, **kwargs):
  new_x = TransformedDistribution(x, bij, *args, **kwargs)
  new_x.support = new_support
  return new_x


def marginal(x, n):
Functions should be placed according to alphabetical ordering of function names.
edward/util/random_variables.py
Outdated
  Returns:
    tf.Tensor.
    The fully sampled values from x, of shape [n] + x.shape
x.shape = x.sample_shape + x.batch_shape + x.event_shape. You replace the sample_shape, so the output should have shape [n] + x.batch_shape + x.event_shape. But I guess it currently fails if x.sample_shape is non-scalar anyways.
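The shape arithmetic above can be sketched with plain tuples. The helper name `expected_marginal_shape` is made up for illustration, and raising on a non-scalar `sample_shape` simply mirrors the limitation mentioned above rather than describing what the PR's code does.

```python
def expected_marginal_shape(n, sample_shape, batch_shape, event_shape):
    """Shape of n marginal draws: sample_shape is replaced by [n]."""
    if tuple(sample_shape) != ():
        # Mirrors the limitation noted in the review; not the PR's behaviour.
        raise ValueError("this sketch only handles a scalar sample_shape")
    return (n,) + tuple(batch_shape) + tuple(event_shape)


# Example: a batch of 3 bivariate RVs, drawn 10 times.
print(expected_marginal_shape(10, (), (3,), (2,)))  # (10, 3, 2)
```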
new_roots = []
for rv in old_roots:
  new_rv = copy(rv)
  new_rv._sample_shape = tf.TensorShape(n).concatenate(new_rv._sample_shape)
tf.TensorShape() fails if n is a tf.Tensor.
This also came up when I looked into #774, and I think it would need to be solved at the same time. Sample shape needs a TensorShape, and there's no nice way to turn tensor n into one (I don't think).
So I think one option is that sample_shape is interpreted as 'whatever gets passed to sample' and is therefore stored as a tensor, not a TensorShape. This would solve #774, and I don't think it would break much. Another alternative is having sample_shape and sample_shape_tensor attributes built from the actual tensor representation of the RV.
Let me know if you'd prefer this to be implemented in the same PR; I'll push the other changes.
tests/util/test_marginal.py
Outdated
import tensorflow as tf

from edward.models import Normal, InverseGamma
from tensorflow.contrib.distributions import bijectors
I found this test very intuitive. Great work. One note: you don't use the bijectors module.
tests/util/test_marginal.py
Outdated
  def test_sample_passthrough(self):
    with self.test_session():
      loc = Normal(0.0, 100.0)
There is only a very low probability that this produces a false negative/positive, but in general you should always set a seed in tests that check randomness.
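As a rough illustration of the seeding advice, here is a self-contained sketch using only the standard library. The helper `draw_mean` and the test class are hypothetical. In an Edward test the analogous step would presumably be a call such as `ed.set_seed(...)` before building the graph, which is an assumption about how this PR's test would be adapted.

```python
import random
import unittest


def draw_mean(seed, n=1000):
    """Mean of n draws from Normal(0, 100) using a locally seeded generator."""
    rng = random.Random(seed)
    return sum(rng.gauss(0.0, 100.0) for _ in range(n)) / n


class SeededSamplingTest(unittest.TestCase):

    def test_mean_is_reproducible(self):
        # Same seed, same draws: the check can no longer flake between runs.
        self.assertEqual(draw_mean(seed=42), draw_mean(seed=42))

    def test_mean_is_near_zero(self):
        # Standard error is 100 / sqrt(1000), about 3.2, so 20 is a loose bound.
        self.assertLess(abs(draw_mean(seed=42)), 20.0)
```

With a fixed seed, a statistical assertion either always passes or always fails, which makes a failure something to investigate rather than rerun.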
Creating a Marginal class crossed my mind, in particular when I thought about how to implement the API as you described it, first creating the marginal RV and then being able to repeatedly sample it. My thoughts are that the best way to add that would be to add a marginal method to RandomVariable, which would perform the graph manipulation done here, but attach a placeholder as we discussed in #759. So it could be something like:
rv.sample(10, deep=True)
# raises ValueError: 'must take marginal of rv before deep sampling'
marginal_rv = rv.marginal()
marginal_rv.sample(10)
# works as before, sets placeholder to one then samples 10 from the distribution
marginal_rv.sample(10, deep=True)
# desired functionality, sets placeholder to 10 then samples once from batch result
It means more functionality is added to RV, but it makes sense that a function on an RV that returns an RV would be a class method. I'd favour the explicit marginal method since we're doing a copy. It also makes me a little less squeamish about all this private attribute access, even though it's not on
Re: #759 this adds limited support for full graph sampling. Given an RV, ed.marginal will traverse its parent graph, replacing any root ancestor instances of RandomVariable with a sampled equivalent, so that each non-root RV in the graph is evaluated with a tensor of batch_shape of parameters.
For example:
This current implementation does not work for graphs of RVs using the sample_shape arg. That will require some refactoring of how RandomVariable internally stores the sample_shape. I'm making this PR mostly because I'm confident that the API will be backwards compatible.
Beyond not allowing sample_shape, ed.marginal can fail in the following ways.
ed.marginal detecting incorrect broadcasting (this prevents situations where sampling 10000 from a scalar RV produces some enthusiastically broadcasted (10000, 10000, 1) shaped tensor).
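The broadcasting hazard can be illustrated with shape arithmetic alone. The sketch below is a plain-Python illustration, not the check that ed.marginal performs: `broadcast_shape` reimplements the usual NumPy-style broadcasting rule, `check_marginal_shape` is a hypothetical guard, and the two input shapes are just one assumed way the (10000, 10000, 1) result could come about.

```python
def broadcast_shape(a, b):
    """NumPy-style broadcast of two shapes given as tuples of ints."""
    # Left-pad the shorter shape with 1s so both have the same rank.
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError("incompatible shapes: %s vs %s" % (a, b))
        out.append(db if da == 1 else da)
    return tuple(out)


def check_marginal_shape(result_shape, n, batch_shape=(), event_shape=()):
    """Hypothetical guard: the result must be exactly [n] + batch + event."""
    expected = (n,) + tuple(batch_shape) + tuple(event_shape)
    if tuple(result_shape) != expected:
        raise ValueError(
            "expected shape %s but got %s; parameters were probably "
            "broadcast against each other" % (expected, tuple(result_shape)))
    return expected


n = 10000
# An (n, 1, 1) tensor combined with an (n, 1) tensor silently broadcasts.
bad = broadcast_shape((n, 1, 1), (n, 1))
print(bad)  # (10000, 10000, 1)
```

Comparing the result against the expected `[n] + batch_shape + event_shape` catches the problem before a tensor with a spurious extra dimension of size n is ever evaluated.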