
Added ZeroSumNormal Distribution #4776


Closed · wants to merge 2 commits

Conversation

@kc611 (Contributor) commented Jun 17, 2021

This PR adds a ZeroSumNormal distribution along with a zerosum transform. The original code for this was written by @aseyboldt (probably).

This is a special case of the Normal distribution which constrains the sum of the means to be zero along the given axes.
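For illustration, here is a minimal usage sketch (assuming the API proposed in this PR, i.e. a sigma argument plus the usual shape kwarg; group_idx and y are made-up data):

import numpy as np
import pymc3 as pm

group_idx = np.array([0, 1, 2, 1, 0])        # hypothetical group labels
y = np.array([0.1, -0.2, 0.3, 0.0, -0.1])    # hypothetical observations

with pm.Model() as model:
    intercept = pm.Normal("intercept", sigma=10)
    # One effect per group, constrained to sum to zero along the last axis,
    # so no reference category / relative coding is needed.
    group_effect = pm.ZeroSumNormal("group_effect", sigma=1, shape=3)
    mu = intercept + group_effect[group_idx]
    pm.Normal("y", mu=mu, sigma=1, observed=y)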

@kc611 kc611 marked this pull request as draft June 17, 2021 06:51
@twiecki (Member) commented Jun 17, 2021

This is great, thanks @kc611. Definitely needs a line in the release-notes.

@twiecki twiecki requested a review from aseyboldt June 18, 2021 08:59
@aseyboldt (Member) left a comment


Thanks for making the PR!
I added a couple of comments.
We should also add some docs, and include a definition of what distribution this actually has (covariance sigma ** 2 * (I - (1/n) J), where I = eye(n) and J = ones(n)).

I think we should have tests that use the zerosum_dims and zerosum_axes parameters.

I'm also not perfectly happy with those two names, but I can't think of anything better. If anyone else has an idea...

Since I wrote the original code, could you add me as a co-author? (https://github.blog/2018-01-29-commit-together-with-co-authors/)
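For reference, a quick NumPy sketch of that covariance (not part of the PR code): sigma ** 2 * (I - J / n) has one zero eigenvalue (the all-ones direction) and n - 1 eigenvalues equal to sigma ** 2.

import numpy as np

n, sigma = 4, 1.5
I, J = np.eye(n), np.ones((n, n))
cov = sigma**2 * (I - J / n)

evals = np.sort(np.linalg.eigvalsh(cov))
print(np.isclose(evals[0], 0.0))          # True: the all-ones direction is degenerate
print(np.allclose(evals[1:], sigma**2))   # True: the remaining n - 1 eigenvalues are sigma**2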

    super().__init__(**kwargs)

def logp(self, value):
    return Normal.dist(sigma=self.sigma / self._rescaling).logp(value)
Member:

I think this would be cleaner if we could define the logp on the transformed space. I don't think we implemented a way of doing this yet, did we?

@twiecki (Member) commented Jun 18, 2021

CC @tomicapretto: this will be super neat for Bambi, where we don't have to define a base condition and do relative coding for the others.

@ricardoV94 ricardoV94 added the v3 label Jun 18, 2021
@kc611 kc611 force-pushed the zerosumnormal branch 2 times, most recently from 94b8021 to acf359a Compare June 18, 2021 15:45
@kc611 kc611 closed this Jun 18, 2021
@kc611 kc611 reopened this Jun 18, 2021
@kc611 kc611 force-pushed the zerosumnormal branch 2 times, most recently from 86448c8 to 1b86b65 Compare June 18, 2021 15:54
@twiecki (Member) commented Jul 5, 2021

@kc611 Any progress on this? Seems like the tests are not passing yet.

@kc611 (Contributor, Author) commented Jul 6, 2021

Sorry for the late response. Yeah, this is a shapes problem: the ZeroSumNormal is adding one extra dimension to the output shape. cc @aseyboldt Is this extra dimension in the result intended (in which case the tests would need a change), or do we need an np.squeeze somewhere?

@AlexAndorra (Contributor):

Just an idea in passing: this distribution should also be added in v4 -- it would be a shame not to have it in the new version, especially as it's very useful!
Note that I don't want to add to your to-dos @kc611. Do it if you feel like it; otherwise it can be done by someone else in another PR.

    super().__init__(**kwargs)

def logp(self, value):
    return Normal.dist(sigma=self.sigma).logp(value)
Contributor:

If we don't add the scaling of sigma here, our random method will be inconsistent with the logp.

Contributor (Author):

How so? Isn't random also drawing samples using self.sigma?

Member:

This logp is somewhat strange still. From the math side this should be pm.MvNormal with cov = I - J / n, where J is a matrix of all 1s. We don't want to write it like this though, because we don't want to do matrix factorization, and pm.MvNormal doesn't work if an eigenvalue is 0.
It would be great if, instead, we could define the logp simply in the transformed space. This would imply changes to TransformedDistribution though.

@lucianopaz (Contributor) commented Jan 30, 2022

We can avoid the MvNormal.logp problem if we force self.sigma to be a scalar or to have a single element in all the zerosum_axes. In this case, all the directions in the zerosum manifold are uncorrelated and have equal variance. This means that we can use the Normal.logp as long as we also include a bound condition that guarantees that we are on the zerosum manifold. You can do this by using this logp:

Suggested change
return Normal.dist(sigma=self.sigma).logp(value)
zerosums = [tt.all(tt.abs_(tt.mean(value, axis=axis)) <= 1e-9) for axis in self.zerosum_axes]
return bound(
    Normal.dist(sigma=self.sigma).logp(value),
    tt.all(self.sigma > 0),
    broadcast_conditions=False,
    *zerosums,
)

@lucianopaz (Contributor) commented Feb 4, 2022

I came across this wiki section that talks about the degenerate MvNormal case (which is what we have with the ZeroSumNormal). We could use that formula as the expected logp value and test if the logp that we are using in the distribution matches it. The expected logp would look something like this:

import numpy as np

def pseudo_log_det(A, tol=1e-13):
    v, w = np.linalg.eigh(A)
    return np.sum(np.log(np.where(np.abs(v) >= tol, v, 1)), axis=-1)

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    psdet = 0.5 * pseudo_log_det(2 * np.pi * cov)
    exp = 0.5 * (value[..., None, :] @ np.linalg.pinv(cov) @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
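As a sanity check (a sketch, not part of the PR), this logp can be compared against scipy's multivariate_normal with allow_singular=True, which implements the same pseudo-determinant / pseudo-inverse density:

# Assumes the pseudo_log_det / logp functions defined just above.
import numpy as np
from scipy import stats

n, sigma = 4, 0.7
rng = np.random.default_rng(0)
value = rng.normal(size=n)
value -= value.mean()  # put the draw exactly on the zero-sum manifold

cov = sigma**2 * (np.eye(n) - np.ones((n, n)) / n)
reference = stats.multivariate_normal(np.zeros(n), cov, allow_singular=True).logpdf(value)
print(np.isclose(logp(value, sigma), reference))  # True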

@lucianopaz (Contributor) commented Feb 6, 2022

I ran a few tests with the logp and it looks like the logp that we are using in this PR doesn't match what one would expect from a degenerate multivariate normal distribution. In my comment above, I posted what a degenerate MvNormal logp looks like. For this particular problem, where we know that we have only one eigenvector with zero eigenvalue, we can rewrite the logp as:

import numpy as np

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    v, w = np.linalg.eigh(cov)
    psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi))
    cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
    exp = 0.5 * (value[..., None, :] @ cov_pinv @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)

This is different from the logp that we are currently using in this PR. The difference is in the normalization constant:
psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi)). In particular, since all eigenvalues v except the first one are the same and equal to sigma**2, psdet = (n - 1) * (0.5 * np.log(2 * np.pi) + np.log(sigma)). Whereas, with the assumed Normal.dist(sigma=self.sigma).logp(value), the normalization factor we are getting is:

psdet = n * (0.5 * np.log(2 * np.pi) + np.log(sigma))

This means that we have to multiply the logp that we are using by (n - 1) / n (in the case where only one axis sums to zero) to get the correct log probability density. I'll check what happens when more than one axis has to sum to zero.
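To see the normalization difference numerically, here is a small check (a sketch, using scipy's degenerate MvNormal as the reference; not part of the PR code):

import numpy as np
from scipy import stats

n, sigma = 5, 1.3
rng = np.random.default_rng(42)
value = rng.normal(size=n)
value -= value.mean()  # restrict to the zero-sum manifold

cov = sigma**2 * (np.eye(n) - np.ones((n, n)) / n)
degenerate = stats.multivariate_normal(np.zeros(n), cov, allow_singular=True).logpdf(value)
naive = stats.norm(0, sigma).logpdf(value).sum()  # what summing Normal.logp gives

# The naive sum pays the normalization constant for n dimensions,
# the degenerate density only for n - 1:
print(np.isclose(degenerate - naive, 0.5 * np.log(2 * np.pi) + np.log(sigma)))  # True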

@twiecki (Member) commented Aug 16, 2021

@kc611 Where are we on this?

@zoj613 (Contributor) commented Aug 16, 2021

The name of the class suggests that the samples themselves are to sum to zero, not the means as explained in the PR description. Would a name like ZeroMeanSumNormal/MeanZeroSumNormal not be an alternative in this case?

@twiecki (Member) commented Aug 16, 2021

Or just MeanZeroNormal?

Co-authored-by: Adrian Seyboldt <[email protected]>
@kc611 (Contributor, Author) commented Aug 16, 2021

It seemed like a simple type filtering issue, but thanks for the reminder. Anyway, I've put it on testing right now.

What should we rename the distribution to? (For now I've left it as ZeroSumNormal.)

@drbenvincent commented:

FYI this is being showcased in a PR to PyMC3-examples here pymc-devs/pymc-examples#210
In that example I noticed graphviz comes out like this
[Screenshot: graphviz model graph, 2021-08-16]
Not mission critical, but it would be pretty neat if it displayed as a specific \beta \sim \text{ZeroSumNormal} distribution node 😎
Please do let me know if I've not got the attributions/credits right.

@twiecki twiecki marked this pull request as ready for review August 17, 2021 12:30
@@ -98,7 +98,7 @@ def __new__(cls, name, *args, **kwargs):
        raise TypeError("observed needs to be data but got: {}".format(type(data)))
        total_size = kwargs.pop("total_size", None)

-       dims = kwargs.pop("dims", None)
+       dims = kwargs["dims"] if "dims" in kwargs else None
Member:

This is a breaking change to the pm.Distribution API. Right now you can define new distributions that do not accept a dims kwarg, but with this they have to accept it. Maybe we can live with this, but at the very least it should be mentioned in the release notes somewhere.

Contributor:

I reverted that change, as we're now using __new__ to enable the zerosum_dims kwarg.


@aseyboldt (Member):

@zoj613 The description is a bit strange with the "which constrains the sum of the means to be zero". It really constrains the sum (and equivalently the mean) to zero.

@aseyboldt (Member):

@drbenvincent The strange graphviz output is because the other PR is using an old version of the ZeroSumNormal that is implemented as a deterministic.

@aseyboldt (Member):

Oh, and I think we have to check if the axes length is 1 or 0 and handle those cases correctly, or throw a good error.

@twiecki (Member) commented Sep 16, 2021

@kc611 Any progress on this?


    return super().__new__(cls, name, *args, **kwargs)

def __init__(self, sigma=1, zerosum_axes=None, zerosum_dims=None, **kwargs):
Contributor:

Note that zerosum_dims is not used in __init__, but if I don't put it here, it doesn't seem to be passed on to __new__: TypeError: __init__() got an unexpected keyword argument 'zerosum_dims'
Not sure we can do it otherwise though. If someone has a better idea, I'm all ears

Member:

I think the zerosum_dims is probably still in kwargs from line 949? We could just remove it there.

@AlexAndorra (Contributor) commented Oct 5, 2021

@aseyboldt and I have implemented the zerosum_dims kwarg (with the caveat I put in my message just above; not sure how big a deal it is).

IIUC, only two things are blocking merge for now:

  1. Adding a test of this zerosum_dims feature
  2. Deciding on the parametrization we wanna use for sigma

We'll be looking at those during a hackathon this Friday

@twiecki (Member) commented Oct 5, 2021

Why even add this to v3 and not v4 directly?

@AlexAndorra (Contributor) commented Oct 5, 2021

The best reason I've found is the sunk-cost fallacy 😅
TBH I agree: let's only add that in v4. That means no duplicated effort, and it nudges people towards v4.

@aseyboldt (Member):

About the meaning of the (scalar) sigma parameter, and whether that should be the std of the underlying normal, or the std of the marginals (cc @lucianopaz):

I think we have basically two options about what sigma could refer to:

  • The marginal std, so pm.ZeroSumNormal(sigma=1).random(1000).std() == 1. We get C = sigma ** 2 * (I - J / n) / (1 - 1 / n) as covariance.
  • The sqrt of the non-zero eigenvalues of the covariance matrix. This would imply pm.ZeroSumNormal(sigma=1).random().std() == sqrt(1 - 1 / n). We get C = sigma ** 2 * (I - J / n) as covariance.

The obvious downside of the second option would be that it might be confusing to users, as the marginal std is not what people might expect. But I think there are reasons for the second option, and I think I'd prefer that one. It will be very confusing if we decide to change it later however, so we should give it some thought.

Let's assume we use the ZeroSumNormal in a regression (all sigma values in terms of the second option):

with pm.Model() as model:
    intercept = pm.Normal("intercept", sd=10)
    group_effect = pm.Normal("group_effect", sd=1, dims="group")
    
    mu = (
        intercept
        + group_effect[group_idx]
    )

We can rewrite this as follows (without changing the model; only the parametrization is different):

with pm.Model() as model:
    intercept = pm.Normal("intercept", sd=10)

    mean_sigma = at.sqrt(1 ** 2 / len(groups))
    group_effect_mean = pm.Normal("group_effect_mean", sigma=mean_sigma)
    group_effect_diff = pm.ZeroSumNormal("group_effect_diff", sigma=1, dims="group")
    
    group_effect = pm.Deterministic("group_effect", group_effect_mean + group_effect_diff)
    
    mu = (
        intercept
        + group_effect_mean
        + group_effect_diff[group_idx]
    )

Now we use the fact that the sum of two normal dists is a normal again (the intercept + group_effect_mean):

with pm.Model() as model:
    mean_sigma = at.sqrt(1 ** 2 / len(groups))
    intercept_plus_mean = pm.Normal("intercept", sd=at.sqrt(10 ** 2 + mean_sigma ** 2))

    group_effect_diff = pm.ZeroSumNormal("group_effect_diff", sigma=1, dims="group")

    mu = (
        intercept_plus_mean
        + group_effect_diff[group_idx]
    )

So we turned our model that assumed group_effect ~ Normal(0, 1) into a model using the ZeroSumNormal, but interestingly the sigma value is still 1, independent of the number of groups. If we had used the first option for sigma, we would have had to set sigma to something that depends on the number of groups.
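To make the two options concrete, a small simulation sketch (projecting i.i.d. normals onto the zero-sum manifold yields exactly the second option's covariance sigma ** 2 * (I - J / n)):

import numpy as np

n, sigma = 6, 1.0
rng = np.random.default_rng(123)

draws = rng.normal(0, sigma, size=(200_000, n))
draws -= draws.mean(axis=-1, keepdims=True)   # zero-sum along the last axis

print(draws.std())           # ~ 0.913, i.e. sqrt(1 - 1/n): the marginal std under option 2
print(np.sqrt(1 - 1 / n))
# Under option 1 we would rescale by 1 / sqrt(1 - 1/n) so the marginal std comes out as 1.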

@AlexAndorra (Contributor):

Thanks for this detailed explanation @aseyboldt ! I remember that discussion from a late-night train commute in Lisbon 😅
So, IIUC:

  • The plus of the first option is that it's easy to understand and parametrize: the resulting ZeroSumNormal's marginals have std = 1. The downside is that, to make sure of that, the prior on ZeroSumNormal's sigma has to depend on the number of groups (which may in turn hurt users' understanding of the parametrization).
  • The plus of the second option is that the sigma prior doesn't depend on the number of groups, but the interpretation of the marginals (the sqrt of the non-zero eigenvalues of the covariance matrix) is... not self-evident, to put it gently.

Is my summary mistaken in any way?

@aseyboldt (Member):

I think that's a good summary :-)

A related question is what we would want to happen if sigma is not a scalar. Maybe the safest thing is to just not allow that for now.

@AlexAndorra (Contributor):

That would mean a different sigma for each group?

@twiecki (Member) commented Jan 19, 2022

@AlexAndorra Any plans to revive this?

@AlexAndorra (Contributor):

Yeah yeah, it's still on my radar, but this week is very busy work-wise. Hope to start next week. In any case, it's not gonna be a short PR.

@lucianopaz (Contributor) left a comment

I'm picking up some slack here because I started looking into porting the ZeroSumNormal into v4, and I found a small bug with prior predictive sampling.
We should add a test for the random method that asserts that we get results allclose to 0 on the axes that we want to sum to zero.

if shape:
    zerosum_axes = (-1,)
else:
    zerosum_axes = ()
Contributor:

I think it makes no sense to have a ZeroSumNormal when shape=() or None. In that case, the RV should also be exactly equal to zero. I think that we should test if shape is None or len(shape) == 0 and raise a ValueError in that case, with something that says that ZeroSumNormal is defined only for RVs that are not scalar.
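A sketch of such a guard (hypothetical helper name; it just needs to reject scalar shapes early with a clear message):

def _check_zerosum_shape(shape):
    # Reject scalar shapes early, as suggested above.
    if shape is None or len(shape) == 0:
        raise ValueError(
            f"ZeroSumNormal is only defined for non-scalar RVs, but got shape={shape!r}."
        )

_check_zerosum_shape((3, 4))   # fine
_check_zerosum_shape(())       # raises ValueError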

if isinstance(zerosum_axes, int):
    zerosum_axes = (zerosum_axes,)

self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
Contributor:

Enforcing positive axes here leads to problems when you draw samples from the prior predictive. It's better to replace this line with this:

Suggested change
self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
self.zerosum_axes = [a if a < 0 else a - len(shape) for a in zerosum_axes]
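To illustrate why normalizing to negative axes is more robust here (a NumPy sketch, not the PR code): prior predictive sampling prepends extra sample dimensions, so positive axis indices stop pointing at the intended core dimension, while negative ones keep working.

import numpy as np

shape = (3, 4)       # the RV's declared shape
zerosum_axis = 1     # the user asks for the second core dimension to sum to zero

# Prior predictive draws come with extra leading sample dimensions:
draws = np.random.default_rng(1).normal(size=(500,) + shape)

neg_axis = zerosum_axis - len(shape)   # -> -1, as in the suggested change
centered = draws - draws.mean(axis=neg_axis, keepdims=True)
print(np.allclose(centered.sum(axis=neg_axis), 0))   # True: still the intended dimension
# With the positive index 1, the same operation would hit the wrong axis (size 3, not 4).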

@lucianopaz (Contributor) left a comment

I also wanted to suggest how to change the logp (see the suggested change above).


@ricardoV94 (Member):

With V4 closer and closer to being out, this PR makes less sense. I'll close it soon if nobody objects.

Possibly, we should add the ZeroSumNormal to pymc-experimental anyway.

@twiecki (Member) commented Mar 18, 2022

Agreed.

@twiecki twiecki closed this Mar 18, 2022
@lucianopaz lucianopaz mentioned this pull request Sep 17, 2022