add exact line-search for (f)gw solvers with kl_loss (#556)
cedricvincentcuaz authored Nov 4, 2023
1 parent 53dde7a · commit a73ad08
Showing 3 changed files with 98 additions and 78 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
@@ -13,6 +13,7 @@
+ Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543)
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
+ Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544)
+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
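The new entry can be exercised directly: with armijo=False (the default), the conditional-gradient solvers now use the exact line-search for 'kl_loss' instead of forcing an Armijo backtracking search. A minimal usage sketch (assumes POT >= 0.9.2; data and sizes are illustrative, not from this PR):

import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.rand(20, 2), rng.rand(30, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)   # intra-domain structure matrices
C1, C2 = C1 / C1.max(), C2 / C2.max()
p, q = ot.unif(20), ot.unif(30)

# armijo=False (the default) now selects the exact line-search with 'kl_loss'
gw_kl = ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun='kl_loss', armijo=False)
print(gw_kl)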
78 changes: 48 additions & 30 deletions ot/gromov/_gw.py
@@ -160,15 +160,17 @@ def df(G):

        def df(G):
            return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
-    if loss_fun == 'kl_loss':
-        armijo = True  # there is no closed form line-search with KL
+
+    # removed since 0.9.2
+    #if loss_fun == 'kl_loss':
+    #    armijo = True  # there is no closed form line-search with KL

    if armijo:
        def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
            return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
    else:
        def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
-            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs)
+            return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)
    if log:
        res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
        log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
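The swap from C1, C2 to hC1, hC2 is what makes the line-search loss-generic: the inner loss is used in the factored form L(a, b) = f1(a) + f2(b) - h1(a) * h2(b) (Proposition 1 in Peyré et al. 2016). For 'square_loss', h1(a) = a and h2(b) = 2 * b, so the factor 2 formerly hard-coded in the line-search moves into hC2; for 'kl_loss', h1(a) = a and h2(b) = log(b). A short sketch of where these matrices come from, assuming ot.gromov.init_matrix implements this decomposition (the 1e-15 guard is an assumption, mirroring the guards visible elsewhere in this diff):

import numpy as np
from ot.gromov import init_matrix

rng = np.random.RandomState(0)
C1, C2 = rng.rand(5, 5), rng.rand(4, 4)
p, q = np.full(5, 1 / 5), np.full(4, 1 / 4)

# constC gathers the f1/f2 terms; hC1 = h1(C1) and hC2 = h2(C2)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun='kl_loss')
np.testing.assert_allclose(hC1, C1)                  # h1(a) = a
np.testing.assert_allclose(hC2, np.log(C2 + 1e-15))  # h2(b) = log(b), guard assumed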
@@ -296,9 +298,13 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
    if loss_fun == 'square_loss':
        gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
        gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-        gw = nx.set_gradients(gw, (p, q, C1, C2),
-                              (log_gw['u'] - nx.mean(log_gw['u']),
-                               log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
+    elif loss_fun == 'kl_loss':
+        gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+        gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+
+    gw = nx.set_gradients(gw, (p, q, C1, C2),
+                          (log_gw['u'] - nx.mean(log_gw['u']),
+                           log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))

    if log:
        return gw, log_gw
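For a fixed plan T, the GW energy is E(C1, C2) = sum_{ijkl} L(C1_ik, C2_jl) T_ij T_kl, so the kl_loss gradients above can be sanity-checked by finite differences. A standalone NumPy sketch (not POT code; sizes are illustrative):

import numpy as np

rng = np.random.RandomState(0)
ns, nt = 5, 4
C1 = rng.rand(ns, ns) + 0.5   # keep entries away from 0 for the KL loss
C2 = rng.rand(nt, nt) + 0.5
T = rng.rand(ns, nt)
T /= T.sum()                  # any fixed coupling works for this check
q = T.sum(axis=0)

def energy(C2_):
    # E = sum_{ijkl} (C1_ik * log(C1_ik / C2_jl) - C1_ik + C2_jl) * T_ij * T_kl
    L = (C1[:, None, :, None] * (np.log(C1)[:, None, :, None] - np.log(C2_)[None, :, None, :])
         - C1[:, None, :, None] + C2_[None, :, None, :])
    return np.einsum('ijkl,ij,kl->', L, T, T)

gC2 = -(T.T @ C1 @ T) / C2 + np.outer(q, q)   # closed form from the diff above

eps = 1e-6
num = np.zeros_like(C2)
for j in range(nt):
    for l in range(nt):
        C2p = C2.copy()
        C2p[j, l] += eps
        num[j, l] = (energy(C2p) - energy(C2)) / eps
np.testing.assert_allclose(num, gC2, atol=1e-5)   # agreement up to FD error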
@@ -449,15 +455,16 @@ def df(G):
        def df(G):
            return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))

-    if loss_fun == 'kl_loss':
-        armijo = True  # there is no closed form line-search with KL
+    # removed since 0.9.2
+    #if loss_fun == 'kl_loss':
+    #    armijo = True  # there is no closed form line-search with KL

    if armijo:
        def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
            return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
    else:
        def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
-            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
+            return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
    if log:
        res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
        log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
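Same pattern for the fused solver: the feature term enters the line-search through M=(1 - alpha) * M with reg=alpha, and armijo=False now reaches the exact line-search for 'kl_loss' too. A hedged usage sketch (assumes POT >= 0.9.2; data are illustrative):

import numpy as np
import ot

rng = np.random.RandomState(42)
xs, xt = rng.rand(15, 2), rng.rand(10, 2)
fs, ft = rng.rand(15, 3), rng.rand(10, 3)   # node features
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)   # structure matrices
M = ot.dist(fs, ft)                         # feature cost
p, q = ot.unif(15), ot.unif(10)

T, log = ot.gromov.fused_gromov_wasserstein(
    M, C1, C2, p, q, loss_fun='kl_loss', alpha=0.5, armijo=False, log=True)
print(log['fgw_dist'])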
@@ -591,18 +598,20 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
    if loss_fun == 'square_loss':
        gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
        gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-        if isinstance(alpha, int) or isinstance(alpha, float):
-            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
-                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
-                                         log_fgw['v'] - nx.mean(log_fgw['v']),
-                                         alpha * gC1, alpha * gC2, (1 - alpha) * T))
-        else:
-            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
-                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
-                                         log_fgw['v'] - nx.mean(log_fgw['v']),
-                                         alpha * gC1, alpha * gC2, (1 - alpha) * T,
-                                         gw_term - lin_term))
+    elif loss_fun == 'kl_loss':
+        gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+        gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+
+    if isinstance(alpha, int) or isinstance(alpha, float):
+        fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
+                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                     log_fgw['v'] - nx.mean(log_fgw['v']),
+                                     alpha * gC1, alpha * gC2, (1 - alpha) * T))
+    else:
+        fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
+                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                     log_fgw['v'] - nx.mean(log_fgw['v']),
+                                     alpha * gC1, alpha * gC2, (1 - alpha) * T,
+                                     gw_term - lin_term))

    if log:
        return fgw_dist, log_fgw
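The extra gradient for a non-scalar alpha follows from fgw_dist = alpha * gw_term + (1 - alpha) * lin_term, whose derivative with respect to alpha is gw_term - lin_term. A sketch of what this enables with a torch backend, in the spirit of test_fgw2_gradients below (random placeholder matrices; passing alpha as a tensor is the point):

import torch
import ot

torch.manual_seed(0)
C1 = torch.rand(6, 6, requires_grad=True)
C2 = torch.rand(7, 7, requires_grad=True)
M = torch.rand(6, 7, requires_grad=True)
alpha = torch.tensor(0.5, requires_grad=True)   # tensor alpha -> gets a gradient

fgw = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, loss_fun='kl_loss', alpha=alpha)
fgw.backward()
print(alpha.grad)   # = gw_term - lin_term at the returned plan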
@@ -613,7 +622,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
                            alpha_min=None, alpha_max=None, nx=None, **kwargs):
    """
-    Solve the linesearch in the FW iterations
+    Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] <references-solve-linesearch>`.

    Parameters
    ----------
@@ -625,9 +634,11 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
    cost_G : float
        Value of the cost at `G`
    C1 : array-like (ns,ns), optional
-        Structure matrix in the source domain.
+        Transformed structure matrix in the source domain.
+        For 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix
    C2 : array-like (nt,nt), optional
-        Structure matrix in the target domain.
+        Transformed structure matrix in the target domain.
+        For 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix
    M : array-like (ns,nt)
        Cost matrix between the features.
    reg : float
@@ -649,11 +660,16 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
    .. _references-solve-linesearch:
    References
    ----------
-    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
-        "Optimal Transport for structured data with application on graphs"
-        International Conference on Machine Learning (ICML). 2019.
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
    """
    if nx is None:
        G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

@@ -664,8 +680,8 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
        nx = get_backend(G, deltaG, C1, C2, M)

    dot = nx.dot(nx.dot(C1, deltaG), C2.T)
-    a = -2 * reg * nx.sum(dot * deltaG)
-    b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
+    a = - reg * nx.sum(dot * deltaG)
+    b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))

    alpha = solve_1d_linesearch_quad(a, b)
    if alpha_min is not None or alpha_max is not None:
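For intuition: along a FW direction, the cost t -> cost(G + t * deltaG) is exactly quadratic for any loss in the factored form, namely a * t**2 + b * t + cost_G with the a and b computed above (the former hard-coded factor 2 is now carried by hC2, since h2(b) = 2 * b for 'square_loss'). solve_1d_linesearch_quad then minimizes that parabola over [0, 1]. A standalone sketch of such a helper (an assumption about ot.optim.solve_1d_linesearch_quad's clipping rule, not a copy of it):

def solve_1d_linesearch_quad_sketch(a, b):
    # Minimize a * t**2 + b * t on [0, 1]; the constant cost_G does not move the argmin.
    if a > 0:
        # convex parabola: clip the unconstrained minimizer -b / (2a) to [0, 1]
        return min(1.0, max(0.0, -b / (2.0 * a)))
    # concave (or linear) case: the minimum sits at an endpoint;
    # g(0) = 0 and g(1) = a + b, so pick whichever is smaller
    return 1.0 if a + b < 0 else 0.0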
@@ -776,8 +792,9 @@ def gromov_barycenters(
    else:
        C = init_C

-    if loss_fun == 'kl_loss':
-        armijo = True
+    # removed since 0.9.2
+    #if loss_fun == 'kl_loss':
+    #    armijo = True

    cpt = 0
    err = 1
@@ -960,8 +977,9 @@ def fgw_barycenters(

    Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

-    if loss_fun == 'kl_loss':
-        armijo = True
+    # removed since 0.9.2
+    #if loss_fun == 'kl_loss':
+    #    armijo = True

    cpt = 0
    err_feature = 1
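With the forced fallback gone in both barycenter solvers, 'kl_loss' can be combined with the exact line-search there as well. A usage sketch (the keyword layout of ot.gromov.gromov_barycenters is assumed from the 0.9.x API; sizes and data are illustrative):

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = [rng.rand(n, 2) for n in (8, 9, 10)]
Cs = [ot.dist(x, x) for x in Xs]
ps = [ot.unif(C.shape[0]) for C in Cs]

# barycenter structure matrix with 7 nodes, KL inner loss, exact line-search
C_bary = ot.gromov.gromov_barycenters(
    7, Cs, ps=ps, p=ot.unif(7), loss_fun='kl_loss', armijo=False, random_state=0)
print(C_bary.shape)   # (7, 7)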
97 changes: 49 additions & 48 deletions test/test_gromov.py
@@ -52,31 +52,31 @@ def test_gromov(nx):
    Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)

    np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
-    gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=True, log=True)
-    gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=True, log=True)
-    gwb = nx.to_numpy(gwb)
-
-    gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False)
-    gw_valb = nx.to_numpy(
-        ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
-    )
-
-    G = log['T']
-    Gb = nx.to_numpy(logb['T'])
-
-    np.testing.assert_allclose(gw, gwb, atol=1e-06)
-    np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
-
-    np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
-    np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1)  # cf log=False
-
-    # check constraints
-    np.testing.assert_allclose(G, Gb, atol=1e-06)
-    np.testing.assert_allclose(
-        p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
-    np.testing.assert_allclose(
-        q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
+    for armijo in [False, True]:
+        gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=armijo, log=True)
+        gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=armijo, log=True)
+        gwb = nx.to_numpy(gwb)
+
+        gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=armijo, G0=G0, log=False)
+        gw_valb = nx.to_numpy(
+            ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=armijo, G0=G0b, log=False)
+        )
+
+        G = log['T']
+        Gb = nx.to_numpy(logb['T'])
+
+        np.testing.assert_allclose(gw, gwb, atol=1e-06)
+        np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
+
+        np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
+        np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1)  # cf log=False
+
+        # check constraints
+        np.testing.assert_allclose(G, Gb, atol=1e-06)
+        np.testing.assert_allclose(
+            p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
+        np.testing.assert_allclose(
+            q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
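Looping over armijo in [False, True] makes the KL tests cover both the Armijo search and the new exact line-search. Since the problem is non-convex, the two line-searches may stop at different stationary points, so agreement is expected only approximately. A quick standalone cross-check in the same spirit (illustrative data):

import numpy as np
import ot

rng = np.random.RandomState(1)
x, y = rng.rand(12, 2), rng.rand(12, 2)
C1, C2 = ot.dist(x, x), ot.dist(y, y)
p = q = ot.unif(12)

gw_armijo = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True)
gw_exact = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=False)
print(gw_armijo, gw_exact)   # typically agree to a few decimals on easy inputs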


def test_asymmetric_gromov(nx):
@@ -1191,33 +1191,34 @@ def test_asymmetric_fgw(nx):
    np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)

    # Tests with kl-loss:
-    G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
-    Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
-    Gb = nx.to_numpy(Gb)
-    # check constraints
-    np.testing.assert_allclose(G, Gb, atol=1e-06)
-    np.testing.assert_allclose(
-        p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
-    np.testing.assert_allclose(
-        q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
-
-    np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
-    np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
-
-    fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
-    fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
-
-    G = log['T']
-    Gb = nx.to_numpy(logb['T'])
-    # check constraints
-    np.testing.assert_allclose(G, Gb, atol=1e-06)
-    np.testing.assert_allclose(
-        p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
-    np.testing.assert_allclose(
-        q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
-
-    np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
-    np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+    for armijo in [False, True]:
+        G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, armijo=armijo, G0=G0, log=True, symmetric=False, verbose=True)
+        Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, armijo=armijo, log=True, symmetric=None, G0=G0b, verbose=True)
+        Gb = nx.to_numpy(Gb)
+        # check constraints
+        np.testing.assert_allclose(G, Gb, atol=1e-06)
+        np.testing.assert_allclose(
+            p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
+        np.testing.assert_allclose(
+            q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
+
+        np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+        np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+        fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+        fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+        G = log['T']
+        Gb = nx.to_numpy(logb['T'])
+        # check constraints
+        np.testing.assert_allclose(G, Gb, atol=1e-06)
+        np.testing.assert_allclose(
+            p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
+        np.testing.assert_allclose(
+            q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
+
+        np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+        np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)


def test_fgw2_gradients():