Skip to content

Commit

Permalink
[MRG] Fix Gradient scaling in Partial GW solver (#602)
Browse files Browse the repository at this point in the history
* new file:   ot/partial_gw.py

* remove partial_gw.py to update existing file partial.py

* fix pep8

---------

Co-authored-by: Rémi Flamary <[email protected]>
Co-authored-by: Cédric Vincent-Cuaz <[email protected]>
  • Loading branch information
3 people authored Jun 21, 2024
1 parent 628a089 commit 14c08ba
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- Fix same sign error for sr(F)GW conditional gradient solvers (PR #611)
- Split `test/test_gromov.py` into `test/gromov/` (PR #619)
- Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628)
- Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602)

## 0.9.3
*January 2024*
Expand Down
15 changes: 9 additions & 6 deletions ot/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
"""

# Author: Laetitia Chapel <[email protected]>
# License: MIT License
# Yikun Bai < [email protected] >
# Cédric Vincent-Cuaz <[email protected]>

import numpy as np
from .lp import emd
from .backend import get_backend
from .utils import list_to_array
from .backend import get_backend
from .lp import emd
import numpy as np

# License: MIT License


def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
Expand Down Expand Up @@ -581,7 +584,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
" equal than min(|a|_1, |b|_1).")

if G0 is None:
G0 = np.outer(p, q)
G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q.

dim_G_extended = (len(p) + nb_dummies, len(q) + nb_dummies)
q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies)
Expand All @@ -597,7 +600,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,

Gprev = np.copy(G0)

M = gwgrad_partial(C1, C2, G0)
M = 0.5 * gwgrad_partial(C1, C2, G0) # rescaling the gradient with 0.5 for line-search while not changing Gc
M_emd = np.zeros(dim_G_extended)
M_emd[:len(p), :len(q)] = M
M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2
Expand Down

0 comments on commit 14c08ba

Please sign in to comment.