Added ZeroSumNormal Distribution #4776
Conversation
This is great, thanks @kc611. Definitely needs a line in the release-notes.
Thanks for making the PR!
I added a couple of comments.
We should also add some docs, and include a definition of what distribution this actually has: the covariance is sigma**2 * (I - J/n), where I = eye(n) and J is the n x n matrix of ones.
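For reference, a minimal numpy sketch of that covariance (the helper name zerosum_cov is purely illustrative, not part of the PR):

import numpy as np

def zerosum_cov(sigma, n):
    # Covariance of an n-dimensional ZeroSumNormal with scalar sigma:
    # sigma**2 * (I - J/n); it is singular along the all-ones direction.
    I = np.eye(n)
    J = np.ones((n, n))
    return sigma ** 2 * (I - J / n)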
I think we should have tests that use the zerosum_dims and zerosum_axes parameters.
I'm also not perfectly happy with those two names, can't think of anything better though. If anyone else has an idea...
Since I wrote the original code, could you add me as a co-author? (https://github.blog/2018-01-29-commit-together-with-co-authors/)
pymc3/distributions/continuous.py
Outdated
super().__init__(**kwargs)

def logp(self, value):
    return Normal.dist(sigma=self.sigma / self._rescaling).logp(value)
I think this would be cleaner if we could define the logp on the transformed space. I don't think we implemented a way of doing this yet, did we?
CC @tomicapretto this will be super neat for Bambi where we don't have to define a base-condition and do relative coding for the others.
@kc611 Any progress on this? Seems like the tests are not passing yet.
Sorry for the late response, and yeah, this is a shapes problem, the
Just an idea in passing: this distribution should also be added in v4 -- would be a shame to not have it in the new version, especially as it's very useful!
super().__init__(**kwargs)

def logp(self, value):
    return Normal.dist(sigma=self.sigma).logp(value)
If we don't add the scaling of sigma here, our random method will be inconsistent with the logp.
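A standalone numpy illustration of that scaling concern (not from the PR, just for intuition): projecting iid normal draws onto the zero-sum space shrinks the marginal standard deviation by sqrt((n - 1) / n), which is why random and logp need a consistent rescaling.

import numpy as np

rng = np.random.default_rng(0)
n, sigma = 5, 1.0
draws = rng.normal(0.0, sigma, size=(100_000, n))
# Enforce the zero-sum constraint by removing the mean along the last axis.
zerosum_draws = draws - draws.mean(axis=-1, keepdims=True)
print(zerosum_draws.std(), sigma * np.sqrt((n - 1) / n))  # approximately equal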
How so? Isn't random also drawing samples using self.sigma?
This logp is somewhat strange still. From the math side this should be pm.MvNormal with cov = I - J / n, where J is a matrix of all 1s. We don't want to write it like this though, because we don't want to do matrix factorization, and pm.MvNormal doesn't work if an eigenvalue is 0.
It would be great if instead we could define the logp simply in the transformed space. This would imply changes to TransformedDistribution though.
We can avoid the MvNormal.logp problem if we force self.sigma to be a scalar or to have a single element in all the zerosum_axes. In this case, all the directions in the zerosum manifold are uncorrelated and have equal variance. This means that we can use the Normal.logp as long as we also include a bound condition that guarantees that we are on the zerosum manifold. You can do this by using this logp:
return Normal.dist(sigma=self.sigma).logp(value)

becomes:

zerosums = [tt.all(tt.abs_(tt.mean(value, axis=axis)) <= 1e-9) for axis in self.zerosum_axes]
return bound(
    pm.Normal.dist(sigma=self.sigma).logp(value),
    tt.all(self.sigma > 0),
    broadcast_conditions=False,
    *zerosums,
)
I came across this wiki section that talks about the degenerate MvNormal case (which is what we have with the ZeroSumNormal). We could use that formula as the expected logp value and test if the logp that we are using in the distribution matches it. The expected logp would look something like this:
import numpy as np

def pseudo_log_det(A, tol=1e-13):
    v, w = np.linalg.eigh(A)
    return np.sum(np.log(np.where(np.abs(v) >= tol, v, 1)), axis=-1)

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    psdet = 0.5 * pseudo_log_det(2 * np.pi * cov)
    exp = 0.5 * (value[..., None, :] @ np.linalg.pinv(cov) @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
I ran a few tests with the logp, and it looks like the logp that we are using in this PR doesn't match what one would expect from a degenerate multivariate normal distribution. In my comment above, I posted what a degenerate MvNormal logp looks like. For this particular problem, where we know that we have only one eigenvector with zero eigenvalue, we can re-write the logp as:
def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    v, w = np.linalg.eigh(cov)
    psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi))
    cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
    exp = 0.5 * (value[..., None, :] @ cov_pinv @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
This is different from the logp that we are currently using in this PR. The difference is in the normalization constant: psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi)). In particular, since all eigenvalues v except the first one are the same and equal to sigma**2, psdet = (n - 1) * (0.5 * np.log(2 * np.pi) + np.log(sigma)). Whereas with the assumed pm.Normal.dist(sigma=self.sigma).logp(x), the normalization factor we are getting is psdet = n * (0.5 * np.log(2 * np.pi) + np.log(sigma)).
This means that we have to multiply the logp that we are using by (n - 1) / n (in the case where only one axis sums to zero) to get the correct log probability density. I'll check what happens when more than one axis has to sum to zero.
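A quick numeric check of that normalization argument (illustrative only, not part of the PR): for scalar sigma the nonzero eigenvalues of sigma**2 * (I - J/n) all equal sigma**2, so the pseudo-determinant term only counts n - 1 directions.

import numpy as np

n, sigma = 4, 1.5
cov = sigma ** 2 * (np.eye(n) - np.ones((n, n)) / n)
v = np.linalg.eigvalsh(cov)  # ascending: one ~0 eigenvalue, then n - 1 copies of sigma**2
psdet_degenerate = 0.5 * (np.sum(np.log(v[1:])) + (n - 1) * np.log(2 * np.pi))
psdet_full = n * (0.5 * np.log(2 * np.pi) + np.log(sigma))
print(psdet_degenerate, psdet_full)  # differ by 0.5 * np.log(2 * np.pi) + np.log(sigma)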
@kc611 Where are we on this?
The name of the class suggests that the samples themselves are to sum to zero, not the means as explained in the PR description. Would a name like
Or just
Co-authored-by: Adrian Seyboldt <[email protected]>
It seemed like a simple type filtering issue. But thanks for the reminder. Anyway, I've put it on testing right now. What should we rename the distribution to? (For now I've let it be
FYI this is being showcased in a PR to
pymc3/distributions/distribution.py
Outdated
@@ -98,7 +98,7 @@ def __new__(cls, name, *args, **kwargs):
     raise TypeError("observed needs to be data but got: {}".format(type(data)))
     total_size = kwargs.pop("total_size", None)

-    dims = kwargs.pop("dims", None)
+    dims = kwargs["dims"] if "dims" in kwargs else None
This is a breaking change to the pm.Distribution API. Right now you can define new distributions that do not accept a dims kwarg, but with this they have to accept it. Maybe we can live with this, but at the very least this should be mentioned in the release notes somewhere.
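A small sketch of the concern (the subclass below is hypothetical): previously dims was popped out of kwargs in Distribution.__new__, so a subclass __init__ never saw it; with this change it is forwarded, and subclasses that don't accept it break.

import pymc3 as pm

# Hypothetical subclass written against the old behaviour:
class MyCustomDist(pm.Continuous):
    def __init__(self, mu=0.0, sigma=1.0):
        super().__init__()
        self.mu = mu
        self.sigma = sigma

# Old behaviour: dims was consumed in Distribution.__new__ and never reached
# __init__. With this change it stays in kwargs, so the call below would fail
# with: TypeError: __init__() got an unexpected keyword argument 'dims'
# with pm.Model(coords={"group": ["a", "b"]}):
#     MyCustomDist("x", dims="group")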
I reverted that change, as we're now using __new__ to enable the zerosum_dims kwarg.
@zoj613 The description is a bit strange with the "which constrains the sum of the means to be zero". It really constrains the sum (and equivalently the mean) to zero.
@drbenvincent The strange graphviz output is because the other PR is using an old version of the ZeroSumNormal that is implemented as a deterministic.
Oh, and I think we have to check if the axes length is 1 or 0 and handle those cases correctly, or throw a good error.
@kc611 Any progress on this?
return super().__new__(cls, name, *args, **kwargs)

def __init__(self, sigma=1, zerosum_axes=None, zerosum_dims=None, **kwargs):
Note that zerosum_dims is not used in __init__, but if I don't put it here, it doesn't seem to be passed on to __new__: TypeError: __init__() got an unexpected keyword argument 'zerosum_dims'. Not sure we can do it otherwise though. If someone has a better idea, I'm all ears.
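A minimal standalone illustration of that Python behaviour (simplified, not the PR code): whatever keywords are given at construction are passed to both __new__ and __init__, so __init__ has to accept zerosum_dims even if it is only used in __new__.

class Demo:
    def __new__(cls, name, *args, zerosum_dims=None, **kwargs):
        # zerosum_dims can be handled here (e.g. translated into zerosum_axes) ...
        return super().__new__(cls)

    def __init__(self, name, sigma=1, zerosum_axes=None, zerosum_dims=None):
        # ... but Python passes the same keywords here too, so zerosum_dims
        # must be accepted even though it is unused.
        self.sigma = sigma
        self.zerosum_axes = zerosum_axes

Demo("x", zerosum_dims=("group",))  # works; removing zerosum_dims from __init__ makes this raise TypeError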
I think the zerosum_dims is probably still in kwargs from line 949? We could just remove it there.
@aseyboldt and I have implemented the … IIUC, only two things are blocking merge for now:
We'll be looking at those during a hackathon this Friday.
Why even add this to v3 and not v4 directly?
The best reason I've found is sunk-cost fallacy 😅
About the meaning of the (scalar) sigma: I think we have basically two options about what it could mean.
The obvious downside of the second option would be that it might be confusing to users, as the marginal std is not what people might expect. But I think there are reasons for the second option, and I think I'd prefer that one. It will be very confusing if we decide to change it later, however, so we should give it some thought. Let's assume we use the ZeroSumNormal in a regression (all sigma values in terms of the second option):

with pm.Model() as model:
    intercept = pm.Normal("intercept", sd=10)
    group_effect = pm.Normal("group_effect", sd=1, dims="group")

    mu = (
        intercept
        + group_effect[group_idx]
    )
We can rewrite this as follows (without changing the model; only the parametrization is different):

with pm.Model() as model:
    intercept = pm.Normal("intercept", sd=10)

    mean_sigma = at.sqrt(1 ** 2 / len(groups))
    group_effect_mean = pm.Normal("group_effect_mean", sigma=mean_sigma)
    group_effect_diff = pm.ZeroSumNormal("group_effect_diff", sigma=1, dims="group")
    group_effect = pm.Deterministic("group_effect", group_effect_mean + group_effect_diff)

    mu = (
        intercept
        + group_effect_mean
        + group_effect_diff[group_idx]
    )
Now we use the fact that the sum of two normal dists is a normal again (the intercept + group_effect_mean):

with pm.Model() as model:
    mean_sigma = at.sqrt(1 ** 2 / len(groups))
    intercept_plus_mean = pm.Normal("intercept", sd=tt.sqrt(10 ** 2 + mean_sigma ** 2))
    group_effect_diff = pm.ZeroSumNormal("group_effect_diff", sigma=1, dims="group")

    mu = (
        intercept_plus_mean
        + group_effect_diff[group_idx]
    )

So we turned our model that assumed that
Thanks for this detailed explanation @aseyboldt! I remember that discussion from a late-night train commute in Lisbon 😅
Is my summary mistaken in any way?
I think that's a good summary :-) A related question is what we would want to happen if sigma is not a scalar. Maybe the safest thing is to just not allow that for now.
That would mean a different sigma for each group?
@AlexAndorra Any plans to revive this?
Yeah yeah, it's still on my radar, but this week is very busy work-wise. Hope to start next week. In any case, it's not gonna be a short PR.
I'm picking up some slack here because I started looking into porting the ZeroSumNormal into v4, and I found a small bug with prior predictive sampling. We should add a test for the random method that asserts that we get allclose results to 0 on the axes that we want to sum to zero.
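A hedged sketch of what such a test could look like (the distribution and keyword names like zerosum_axes follow this PR's discussion; the exact API may differ):

import numpy as np
import pymc3 as pm

def test_zerosumnormal_random_sums_to_zero():
    # Draws from the prior should sum to ~0 along each zero-sum axis.
    dist = pm.ZeroSumNormal.dist(sigma=1, shape=(4, 5), zerosum_axes=(-1,))
    draws = dist.random(size=1000)
    np.testing.assert_allclose(draws.sum(axis=-1), 0, atol=1e-9)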
if shape:
    zerosum_axes = (-1,)
else:
    zerosum_axes = ()
I think it makes no sense to have a ZeroSumNormal when shape=() or None. In that case, the RV should also be exactly equal to zero. I think that we should test if shape is None or len(shape) == 0 and raise a ValueError in that case. Something that says ZeroSumNormal is defined only for RVs that are not scalar.
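A possible version of that guard (a sketch of the suggestion above; the helper name and wording are illustrative):

def _check_shape(shape):
    # Reject scalar shapes: a scalar zero-sum variable would be identically zero.
    if shape is None or len(shape) == 0:
        raise ValueError(
            "ZeroSumNormal is only defined for non-scalar RVs; got a scalar shape."
        )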
if isinstance(zerosum_axes, int):
    zerosum_axes = (zerosum_axes,)

self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
Enforcing a positive axis here leads to problems when you draw samples from the prior predictive. It's better to replace this line with this:

self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]

becomes:

self.zerosum_axes = [a if a < 0 else a - len(shape) for a in zerosum_axes]
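A short numpy illustration of why negative axes are safer here (illustrative, not from the PR): prior predictive sampling prepends sample dimensions, so an axis counted from the left no longer points at the intended dimension, while one counted from the right still does.

import numpy as np

shape = (3, 4)                    # distribution shape, zero-sum over the last axis (length 4)
draws = np.zeros((500,) + shape)  # prior predictive adds a leading sample dimension
positive_axis = len(shape) - 1    # == 1, now points at the dimension of length 3
negative_axis = -1                # still points at the dimension of length 4
print(draws.shape[positive_axis], draws.shape[negative_axis])  # 3 4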
I also wanted to suggest how to change the logp
super().__init__(**kwargs)

def logp(self, value):
    return Normal.dist(sigma=self.sigma).logp(value)
We can avoid the MvNormal.logp problem if we force self.sigma to be a scalar or to have a single element in all the zerosum_axes. In this case, all the directions in the zerosum manifold are uncorrelated and have equal variance. This means that we can use the Normal.logp as long as we also include a bound condition that guarantees that we are on the zerosum manifold. You can do this by using this logp:
return Normal.dist(sigma=self.sigma).logp(value)

becomes:

zerosums = [tt.all(tt.abs_(tt.mean(value, axis=axis)) <= 1e-9) for axis in self.zerosum_axes]
return bound(
    pm.Normal.dist(sigma=self.sigma).logp(value),
    tt.all(self.sigma > 0),
    broadcast_conditions=False,
    *zerosums,
)
With v4 closer and closer to being out, this PR makes less sense. I'll close it soon if nobody objects. Possibly, we should add the ZeroSumNormal to pymc-experimental anyway.
Agreed. |
This PR adds a ZeroSumNormal distribution along with a zerosum transform. The original code for this was written by @aseyboldt (probably). This is a case of a Normal distribution which constrains the sum of the means to be zero along the given axes.
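For context, a hedged usage sketch of the API this PR aims for (keyword names such as dims/zerosum_dims follow the discussion above and may change before merge):

import pymc3 as pm

coords = {"group": ["a", "b", "c", "d"]}
with pm.Model(coords=coords) as model:
    # One effect per group, constrained to sum to zero across the "group" dimension.
    group_effect = pm.ZeroSumNormal("group_effect", sigma=1, dims="group")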