Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Bures-Wasserstein Gradient Descent for Bures-Wasserstein Barycenters #680

Open
wants to merge 21 commits into
base: master
Choose a base branch
from

Conversation

clbonet
Copy link
Contributor

@clbonet clbonet commented Oct 19, 2024

Types of changes

This PR aims to add the Bures-Wasserstein gradient descent solver to compute Bures-Wasserstein barycenters (see e.g. Gradient descent algorithms for Bures-Wasserstein barycenters or Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent).

  • Restructured ot.gaussian.bures_wasserstein_barycenter to allow to use different methods
  • Added the previous fixed-point algorithm in ot.gaussian.bures_barycenter_fixpoint
  • Added the Bures-Wasserstein gradient descent in ot.gaussian.bures_barycenter_gradient_descent
  • Added an iteration over the methods in the test test_bures_wasserstein_barycenter
  • Added a test test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter
  • Added batch version of ot.gaussian.bures_wasserstein_distance
  • Trace can be computed for batchs of matrices

Motivation and context / Related issue

The Bures-Wasserstein gradient descent comes with convergence guarantees to solve Bures-Wasserstein barycenters. Moreover, it can also be used in a stochastic way when there are too much Gaussian. Thus, it is a good alternative to the fixed-point algorithm currently implemented.

How has this been tested (if it applies)

I added a test test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter to assess both methods returns the same barycenter. I also added the itertools to test_bures_wasserstein_barycenter.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small comments. I will let @antoinecollas do a proper review he is the expert in Riemannian optimization

ot/utils.py Outdated Show resolved Hide resolved
ot/gaussian.py Outdated Show resolved Hide resolved
Copy link

codecov bot commented Oct 31, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.08%. Comparing base (6311e25) to head (d4045f1).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #680      +/-   ##
==========================================
+ Coverage   97.05%   97.08%   +0.03%     
==========================================
  Files          98       98              
  Lines       19877    20089     +212     
==========================================
+ Hits        19292    19504     +212     
  Misses        585      585              
---- 🚨 Try these New Features:

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great. A few tests especialy about errors are missing

ot/gaussian.py Show resolved Hide resolved
ot/gaussian.py Show resolved Hide resolved
ot/gaussian.py Show resolved Hide resolved
ot/gaussian.py Outdated
# check convergence
if batch_size is not None and batch_size < n:
# TODO: criteria for SGD: on gradients? + test SGD
diff = nx.norm(Cb - Cnew)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not tested

ot/gaussian.py Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants