
The distance between two of the same GMMs is not 0 #695

Open · GilgameshD opened this issue Nov 21, 2024 · 5 comments

@GilgameshD
Describe the bug

The distance between two of the same GMMs is not 0. Sometimes the distance can be as large as 1e-3 when I use my own data. Is this due to a numerical precision problem?

To Reproduce

import numpy as np
import torch
import ot


if __name__ == "__main__":
    # Build a random GMM with K components in D dimensions
    K = 10
    D = 300
    pi0 = np.random.rand(K)
    pi0 /= np.sum(pi0)            # normalize the mixture weights
    mu0 = np.random.rand(K, D)    # component means
    S0 = np.eye(D)[None].repeat(K, axis=0)  # identity covariances

    pi0 = torch.as_tensor(pi0, dtype=torch.float32)
    mu0 = torch.as_tensor(mu0, dtype=torch.float32)
    S0 = torch.as_tensor(S0, dtype=torch.float32)

    # The second GMM is an exact copy of the first
    pi1 = pi0.clone()
    mu1 = mu0.clone()
    S1 = S0.clone()

    # Sanity check: all parameters are element-wise identical
    print((pi0 == pi1).all())
    print((mu0 == mu1).all())
    print((S0 == S1).all())

    # Expected 0, but a small positive value is returned
    dist = ot.gmm.gmm_ot_loss(mu0, mu1, S0, S1, pi0, pi1)
    print(dist)

The output distance of the above code is 1.2001e-05.

Expected behavior

The distance between two identical GMMs should be 0 (up to floating-point precision).

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Ubuntu 22.04
  • Python version: 3.10
  • How was POT installed (source, pip, conda): source
@rflamary
Collaborator

This is interesting. Could it come from the dist_bures_squared function, which might not be exactly 0 on the diagonal, @eloitanguy?

@eloitanguy
Collaborator

That is interesting, I'll look into it.

@eloitanguy
Collaborator

eloitanguy commented Nov 22, 2024

Hi, thanks for your issue, I managed to reproduce it.
As @rflamary suggested, in this example (with np.random.seed(0)) torch.diag(ot.gmm.dist_bures_squared(mu0, mu1, S0, S1)) is not the zero vector it should be. It turns out this is because ot.dist(mu0, mu1) has nonzero diagonal entries of about 10^(-5), which is consistent with the final GMM distance of roughly 10^(-5) instead of numerical 0.
If you use torch.float64 instead of torch.float32, the 10^(-5) diagonal entries in ot.dist(mu0, mu1) drop to about 10^(-14), which is acceptable. So the error seems to come from numerical imprecision in ot.dist when using torch.float32.
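For reference, a minimal sketch of this diagnosis (not part of the original comment), which only inspects the diagonal of ot.dist on the same random means as in the report above and compares the two dtypes:

import numpy as np
import torch
import ot

np.random.seed(0)
K, D = 10, 300
mu_np = np.random.rand(K, D)

for dtype in (torch.float32, torch.float64):
    mu0 = torch.as_tensor(mu_np, dtype=dtype)
    mu1 = mu0.clone()
    # Squared Euclidean cost between the two (identical) sets of means;
    # its diagonal should be exactly 0.
    M = ot.dist(mu0, mu1)
    print(dtype, torch.diag(M).abs().max().item())

With float32 the diagonal is only approximately zero, while float64 brings it down to negligible values.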

@GilgameshD
Author

Thanks for identifying the problem! Is there any solution other than using torch.float64?

@eloitanguy
Collaborator

Hi, I don't really have other ideas, but maybe @rflamary would know?
I know that ot.dist checks whether the two data matrices are the same object and, in that case, forces the diagonal to 0. This does not solve your issue, but it's closely related, so I'm bringing it up anyway.
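One practical workaround, sketched here and not proposed by the thread participants, simply builds on the float64 observation above: cast the GMM parameters to torch.float64 only for the loss computation, then cast the result back to the original dtype.

import torch
import ot

def gmm_ot_loss_f64(mu0, mu1, S0, S1, pi0, pi1):
    # Promote all parameters to float64 for the distance computation
    # to avoid the float32 imprecision in ot.dist, then return the
    # loss in the original dtype.
    args64 = [t.to(torch.float64) for t in (mu0, mu1, S0, S1, pi0, pi1)]
    loss = ot.gmm.gmm_ot_loss(*args64)
    return loss.to(mu0.dtype)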
