-
Notifications
You must be signed in to change notification settings - Fork 502
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MRG] Fix Gradient scaling in Partial GW solver (#602)
* new file: ot/partial_gw.py * remove partial_gw.py to update existing file partial.py * fix pep8 --------- Co-authored-by: Rémi Flamary <[email protected]> Co-authored-by: Cédric Vincent-Cuaz <[email protected]>
- Loading branch information
1 parent
628a089
commit 14c08ba
Showing
2 changed files
with
10 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,12 +4,15 @@ | |
""" | ||
|
||
# Author: Laetitia Chapel <[email protected]> | ||
# License: MIT License | ||
# Yikun Bai < [email protected] > | ||
# Cédric Vincent-Cuaz <[email protected]> | ||
|
||
import numpy as np | ||
from .lp import emd | ||
from .backend import get_backend | ||
from .utils import list_to_array | ||
from .backend import get_backend | ||
from .lp import emd | ||
import numpy as np | ||
|
||
# License: MIT License | ||
|
||
|
||
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, | ||
|
@@ -581,7 +584,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, | |
" equal than min(|a|_1, |b|_1).") | ||
|
||
if G0 is None: | ||
G0 = np.outer(p, q) | ||
G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. | ||
|
||
dim_G_extended = (len(p) + nb_dummies, len(q) + nb_dummies) | ||
q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies) | ||
|
@@ -597,7 +600,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, | |
|
||
Gprev = np.copy(G0) | ||
|
||
M = gwgrad_partial(C1, C2, G0) | ||
M = 0.5 * gwgrad_partial(C1, C2, G0) # rescaling the gradient with 0.5 for line-search while not changing Gc | ||
M_emd = np.zeros(dim_G_extended) | ||
M_emd[:len(p), :len(q)] = M | ||
M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2 | ||
|