From a73ad08954b384c6dc63eb99c10f0489f0e89b7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sat, 4 Nov 2023 20:57:04 +0100 Subject: [PATCH 1/4] add exact line-search for (f)gw solvers with kl_loss (#556) --- RELEASES.md | 1 + ot/gromov/_gw.py | 78 ++++++++++++++++++++++-------------- test/test_gromov.py | 97 +++++++++++++++++++++++---------------------- 3 files changed, 98 insertions(+), 78 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index f943886d5..7c090bef8 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -13,6 +13,7 @@ + Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543) + Upgraded unbalanced OT solvers for more flexibility (PR #539) + Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544) ++ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index d5e4c7f13..88b1eb75f 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -160,15 +160,17 @@ def df(G): def df(G): return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) - if loss_fun == 'kl_loss': - armijo = True # there is no closed form line-search with KL + + # removed since 0.9.2 + #if loss_fun == 'kl_loss': + # armijo = True # there is no closed form line-search with KL if armijo: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs) + return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs) if log: res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) @@ -296,9 +298,13 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - gw = nx.set_gradients(gw, (p, q, C1, C2), - (log_gw['u'] - nx.mean(log_gw['u']), - log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2)) + elif loss_fun == 'kl_loss': + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) + gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + + gw = nx.set_gradients(gw, (p, q, C1, C2), + (log_gw['u'] - nx.mean(log_gw['u']), + log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2)) if log: return gw, log_gw @@ -449,15 +455,16 @@ def df(G): def df(G): return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) - if loss_fun == 'kl_loss': - armijo = True # there is no closed form line-search with KL + # removed since 0.9.2 + #if loss_fun == 'kl_loss': + # armijo = True # there is no closed form line-search with KL if armijo: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) + return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) @@ -591,18 +598,20 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - if isinstance(alpha, int) or isinstance(alpha, float): - fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), - (log_fgw['u'] - nx.mean(log_fgw['u']), - log_fgw['v'] - nx.mean(log_fgw['v']), - alpha * gC1, alpha * gC2, (1 - alpha) * T)) - else: - - fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha), - (log_fgw['u'] - nx.mean(log_fgw['u']), - log_fgw['v'] - nx.mean(log_fgw['v']), - alpha * gC1, alpha * gC2, (1 - alpha) * T, - gw_term - lin_term)) + elif loss_fun == 'kl_loss': + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) + gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + if isinstance(alpha, int) or isinstance(alpha, float): + fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), + (log_fgw['u'] - nx.mean(log_fgw['u']), + log_fgw['v'] - nx.mean(log_fgw['v']), + alpha * gC1, alpha * gC2, (1 - alpha) * T)) + else: + fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha), + (log_fgw['u'] - nx.mean(log_fgw['u']), + log_fgw['v'] - nx.mean(log_fgw['v']), + alpha * gC1, alpha * gC2, (1 - alpha) * T, + gw_term - lin_term)) if log: return fgw_dist, log_fgw @@ -613,7 +622,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs): """ - Solve the linesearch in the FW iterations + Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] `. Parameters ---------- @@ -625,9 +634,11 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, cost_G : float Value of the cost at `G` C1 : array-like (ns,ns), optional - Structure matrix in the source domain. + Transformed Structure matrix in the source domain. + For the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix C2 : array-like (nt,nt), optional - Structure matrix in the target domain. + Transformed Structure matrix in the source domain. + For the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix M : array-like (ns,nt) Cost matrix between the features. reg : float @@ -649,11 +660,16 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, .. _references-solve-linesearch: + References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + """ if nx is None: G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) @@ -664,8 +680,8 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, nx = get_backend(G, deltaG, C1, C2, M) dot = nx.dot(nx.dot(C1, deltaG), C2.T) - a = -2 * reg * nx.sum(dot * deltaG) - b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) + a = - reg * nx.sum(dot * deltaG) + b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: @@ -776,8 +792,9 @@ def gromov_barycenters( else: C = init_C - if loss_fun == 'kl_loss': - armijo = True + # removed since 0.9.2 + #if loss_fun == 'kl_loss': + # armijo = True cpt = 0 err = 1 @@ -960,8 +977,9 @@ def fgw_barycenters( Ms = [dist(X, Ys[s]) for s in range(len(Ys))] - if loss_fun == 'kl_loss': - armijo = True + # removed since 0.9.2 + #if loss_fun == 'kl_loss': + # armijo = True cpt = 0 err_feature = 1 diff --git a/test/test_gromov.py b/test/test_gromov.py index 06f843a4a..8870a5023 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -52,31 +52,31 @@ def test_gromov(nx): Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) + for armijo in [False, True]: + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=armijo, log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=armijo, log=True) + gwb = nx.to_numpy(gwb) + + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=armijo, G0=G0, log=False) + gw_valb = nx.to_numpy( + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=armijo, G0=G0b, log=False) + ) - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=True, log=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=True, log=True) - gwb = nx.to_numpy(gwb) + G = log['T'] + Gb = nx.to_numpy(logb['T']) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False) - gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) - ) + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) + np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False - np.testing.assert_allclose(gw, gwb, atol=1e-06) - np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - - np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) - np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov def test_asymmetric_gromov(nx): @@ -1191,33 +1191,34 @@ def test_asymmetric_fgw(nx): np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) # Tests with kl-loss: - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True) - Gb = nx.to_numpy(Gb) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) - - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + for armijo in [False, True]: + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, armijo=armijo, G0=G0, log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, armijo=armijo, log=True, symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) def test_fgw2_gradients(): From 10717598a7e991e02d9fbc30d3a05b852916ea2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sat, 4 Nov 2023 23:07:34 +0100 Subject: [PATCH 2/4] add kl_loss to all semi-relaxed (f)gw solvers (#559) --- RELEASES.md | 1 + ot/gromov/_semirelaxed.py | 80 ++++----- ot/gromov/_utils.py | 27 +++- test/test_gromov.py | 332 +++++++++++++++++++------------------- 4 files changed, 230 insertions(+), 210 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 7c090bef8..cdc986624 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,7 @@ + Upgraded unbalanced OT solvers for more flexibility (PR #539) + Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544) + Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556) ++ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 0b905c1fa..cbfe64ea8 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -56,7 +56,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. @@ -92,8 +91,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - if loss_fun == 'kl_loss': - raise NotImplementedError() arr = [C1, C2] if p is not None: arr.append(list_to_array(p)) @@ -139,7 +136,7 @@ def df(G): return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs) + return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, ones_p, M=0., reg=1., fC2t=fC2t, nx=nx, **kwargs) if log: res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) @@ -190,7 +187,6 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. @@ -243,7 +239,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2)) + + elif loss_fun == 'kl_loss': + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) + gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + + srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2)) if log: return srgw, log_srgw @@ -291,7 +292,6 @@ def semirelaxed_fused_gromov_wasserstein( If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. @@ -332,9 +332,6 @@ def semirelaxed_fused_gromov_wasserstein( "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - if loss_fun == 'kl_loss': - raise NotImplementedError() - arr = [M, C1, C2] if p is not None: arr.append(list_to_array(p)) @@ -382,7 +379,7 @@ def df(G): def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return solve_semirelaxed_gromov_linesearch( - G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs) + G, deltaG, cost_G, hC1, hC2, ones_p, M=(1 - alpha) * M, reg=alpha, fC2t=fC2t, nx=nx, **kwargs) if log: res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) @@ -434,7 +431,6 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo If let to its default value None, uniform distribution is taken. loss_fun : str, optional loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. @@ -494,15 +490,20 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - if isinstance(alpha, int) or isinstance(alpha, float): - srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), - (alpha * gC1, alpha * gC2, (1 - alpha) * T)) - else: - lin_term = nx.sum(T * M) - srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha - srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha), - (alpha * gC1, alpha * gC2, (1 - alpha) * T, - srgw_term - lin_term)) + + elif loss_fun == 'kl_loss': + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) + gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + + if isinstance(alpha, int) or isinstance(alpha, float): + srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), + (alpha * gC1, alpha * gC2, (1 - alpha) * T)) + else: + lin_term = nx.sum(T * M) + srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha + srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha), + (alpha * gC1, alpha * gC2, (1 - alpha) * T, + srgw_term - lin_term)) if log: return srfgw_dist, log_fgw @@ -511,7 +512,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, - M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs): + M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs): """ Solve the linesearch in the Conditional Gradient iterations for the semi-relaxed Gromov-Wasserstein divergence. @@ -524,16 +525,22 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration cost_G : float Value of the cost at `G` - C1 : array-like (ns,ns) - Structure matrix in the source domain. - C2 : array-like (nt,nt) - Structure matrix in the target domain. + C1 : array-like (ns,ns), optional + Transformed Structure matrix in the source domain. + Note that for the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix_semirelaxed + C2 : array-like (nt,nt), optional + Transformed Structure matrix in the source domain. + Note that for the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix_semirelaxed ones_p: array-like (ns,1) Array of ones of size ns M : array-like (ns,nt) Cost matrix between the features. reg : float Regularization parameter. + fC2t: array-like (nt,nt), optional + Transformed Structure matrix in the source domain. + Note that for the 'square_loss' and 'kl_loss', we provide fC2t from ot.gromov.init_matrix_semirelaxed. + If fC2t is not provided, it is by default fC2t corresponding to the 'square_loss'. alpha_min : float, optional Minimum value for alpha alpha_max : float, optional @@ -565,11 +572,14 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0) dot = nx.dot(nx.dot(C1, deltaG), C2.T) - C2t_square = C2.T ** 2 - dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square) - dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square) - a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG) - b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG)) + if fC2t is None: + fC2t = C2.T ** 2 + dot_qG = nx.dot(nx.outer(ones_p, qG), fC2t) + dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), fC2t) + + a = reg * nx.sum((dot_qdeltaG - dot) * deltaG) + b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - dot) * G) + nx.sum((dot_qG - nx.dot(nx.dot(C1, G), C2.T)) * deltaG)) + alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) @@ -620,7 +630,6 @@ def entropic_semirelaxed_gromov_wasserstein( If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. epsilon : float Regularization term >0 symmetric : bool, optional @@ -655,8 +664,6 @@ def entropic_semirelaxed_gromov_wasserstein( "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - if loss_fun == 'kl_loss': - raise NotImplementedError() arr = [C1, C2] if p is not None: arr.append(list_to_array(p)) @@ -777,7 +784,6 @@ def entropic_semirelaxed_gromov_wasserstein2( If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. epsilon : float Regularization term >0 symmetric : bool, optional @@ -869,7 +875,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein( If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. epsilon : float Regularization term >0 symmetric : bool, optional @@ -907,8 +912,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein( "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - if loss_fun == 'kl_loss': - raise NotImplementedError() arr = [M, C1, C2] if p is not None: arr.append(list_to_array(p)) @@ -1032,7 +1035,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( If let to its default value None, uniform distribution is taken. loss_fun : str, optional loss function used for the solver either 'square_loss' or 'kl_loss'. - 'kl_loss' is not implemented yet and will raise an error. epsilon : float Regularization term >0 symmetric : bool, optional diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index d77e44f9e..2c1bda823 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -399,6 +399,19 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): h_2(b) &= 2b + The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a \log(a) - a + + f_2(b) &= b + + h_1(a) &= a + + h_2(b) &= \log(b) Parameters ---------- C1 : array-like, shape (ns, ns) @@ -451,9 +464,19 @@ def h1(a): def h2(b): return 2 * b elif loss_fun == 'kl_loss': - raise NotImplementedError() + def f1(a): + return a * nx.log(a + 1e-15) - a + + def f2(b): + return b + + def h1(a): + return a + + def h2(b): + return nx.log(b + 1e-15) else: - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.") + raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))), nx.ones((1, C2.shape[0]), type_as=p)) diff --git a/test/test_gromov.py b/test/test_gromov.py index 8870a5023..a71433bb5 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -1941,32 +1941,33 @@ def test_semirelaxed_gromov(nx): # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) - Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein( - C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True, - G0=None, alpha_min=0., alpha_max=1.) + for loss_fun in ['square_loss', 'kl_loss']: + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True, + G0=None, alpha_min=0., alpha_max=1.) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) # symmetric C1 = 0.5 * (C1 + C1.T) @@ -2025,19 +2026,20 @@ def test_semirelaxed_gromov2_gradients(): if torch.cuda.is_available(): devices.append(torch.device("cuda")) for device in devices: - # semirelaxed solvers do not support gradients over masses yet. - p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) + for loss_fun in ['square_loss', 'kl_loss']: + # semirelaxed solvers do not support gradients over masses yet. + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1) + val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1, loss_fun=loss_fun) - val.backward() + val.backward() - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape def test_srgw_helper_backend(nx): @@ -2057,35 +2059,35 @@ def test_srgw_helper_backend(nx): C1 /= C1.max() C2 /= C2.max() - C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) - Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, 'square_loss', armijo=False, symmetric=True, G0=None, log=True) - - # calls with nx=None - constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') - ones_pb = nx.ones(pb.shape[0], type_as=pb) - - def f(G): - qG = nx.sum(G, 0) - marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) - return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) - - def df(G): - qG = nx.sum(G, 0) - marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) - return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) - - def line_search(cost, G, deltaG, Mi, cost_G): - return ot.gromov.solve_semirelaxed_gromov_linesearch( - G, deltaG, cost_G, C1b, C2b, ones_pb, 0., 1., nx=None) - # feed the precomputed local optimum Gb to semirelaxed_cg - res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) - # check constraints - np.testing.assert_allclose(res, Gb, atol=1e-06) + for loss_fun in ['square_loss', 'kl_loss']: + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun, armijo=False, symmetric=True, G0=None, log=True) + + # calls with nx=None + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun) + ones_pb = nx.ones(pb.shape[0], type_as=pb) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, hC1b, hC2b, ones_pb, 0., 1., fC2t=fC2tb, nx=None) + # feed the precomputed local optimum Gb to semirelaxed_cg + res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) @pytest.mark.parametrize('loss_fun', [ - 'square_loss', - pytest.param('kl_loss', marks=pytest.mark.xfail(raises=NotImplementedError)), + 'square_loss', 'kl_loss', pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), ]) def test_gw_semirelaxed_helper_validation(loss_fun): @@ -2149,32 +2151,33 @@ def test_semirelaxed_fgw(nx): np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) # symmetric - C1 = 0.5 * (C1 + C1.T) - Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + for loss_fun in ['square_loss', 'kl_loss']: + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0b) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0b) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) - srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0) + srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) def test_semirelaxed_fgw2_gradients(): @@ -2203,37 +2206,38 @@ def test_semirelaxed_fgw2_gradients(): devices.append(torch.device("cuda")) for device in devices: # semirelaxed solvers do not support gradients over masses yet. - p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - M1 = torch.tensor(M, requires_grad=True, device=device) + for loss_fun in ['square_loss', 'kl_loss']: + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1) + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun) - val.backward() + val.backward() - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert M1.shape == M1.grad.shape + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape - # full gradients with alpha - p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - M1 = torch.tensor(M, requires_grad=True, device=device) - alpha = torch.tensor(0.5, requires_grad=True, device=device) + # full gradients with alpha + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + alpha = torch.tensor(0.5, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, alpha=alpha) + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun, alpha=alpha) - val.backward() + val.backward() - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert alpha.shape == alpha.grad.shape + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert alpha.shape == alpha.grad.shape def test_srfgw_helper_backend(nx): @@ -2309,27 +2313,28 @@ def test_entropic_semirelaxed_gromov(nx): # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) epsilon = 0.1 - G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=G0) - Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=None) + for loss_fun in ['square_loss', 'kl_loss']: + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=None) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=None) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) # symmetric C1 = 0.5 * (C1 + C1.T) @@ -2382,19 +2387,20 @@ def test_entropic_semirelaxed_gromov_dtype_device(nx): C2 /= C2.max() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) + print(nx.dtype_device(tp)) + for loss_fun in ['square_loss', 'kl_loss']: + C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) - Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( - C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True - ) - gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True - ) + Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) + gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, gw_valb) + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) def test_entropic_semirelaxed_fgw(nx): @@ -2450,29 +2456,30 @@ def test_entropic_semirelaxed_fgw(nx): C1 = 0.5 * (C1 + C1.T) Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) + for loss_fun in ['square_loss', 'kl_loss']: + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) + srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) @pytest.skip_backend("tf", reason="test very slow with tf backend") @@ -2505,15 +2512,16 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx): Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp) - Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True - ) - fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True - ) + for loss_fun in ['square_loss', 'kl_loss']: + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) + fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, fgw_valb) + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, fgw_valb) def test_not_implemented_solver(): @@ -2546,17 +2554,3 @@ def test_not_implemented_solver(): with pytest.raises(ValueError): ot.gromov.entropic_fused_gromov_wasserstein( M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) - - # exact and entropic srgw and srfgw loss functions - loss_fun = 'kl_loss' - with pytest.raises(NotImplementedError): - ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun, armijo=False) - with pytest.raises(NotImplementedError): - ot.gromov.entropic_semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun, epsilon=0.1) - with pytest.raises(NotImplementedError): - ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun) - with pytest.raises(NotImplementedError): - ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p, loss_fun, epsilon=0.1) From fe20bc6f5d2051e57ada85644e4f8303fbf46bdf Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Mon, 6 Nov 2023 10:44:09 +0100 Subject: [PATCH 3/4] [WIP] Add new features to unbalanced solvers (#551) * add new features to unbalanced solvers * add new features to unbalanced solvers * fix bug in test * remove stab_sinkhorn * remove kl * fix bug in lbfgsb_unbalanced * fix bug in lbfgsb_unbalanced * fix bug in KL in sinkhorn_unbalanced * edit release.md * fix test * add test and rearrange arguments * fix test * fix test * fix test * fix bug in test * fix bug in doctest * fix bug in doctest * add test for more coverage --- RELEASES.md | 163 ++------------------- ot/unbalanced.py | 303 +++++++++++++++++++++++++--------------- test/test_unbalanced.py | 202 ++++++++++++++++++++++++--- 3 files changed, 386 insertions(+), 282 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index cdc986624..97834cdfe 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -13,8 +13,7 @@ + Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543) + Upgraded unbalanced OT solvers for more flexibility (PR #539) + Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544) -+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556) -+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) ++ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) @@ -36,143 +35,13 @@ Many other bugs and issues have been fixed and we want to thank all the contribu #### New features -- Gaussian Gromov Wasserstein loss and mapping (PR #498) -- Template-based Fused Gromov Wasserstein GNN layer in `ot.gnn` (PR #488) -- Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483) -- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463) -- Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459) -- Add tests on GPU for master branch and approved PR (PR #473) -- Add `median` method to all inherited classes of `backend.Backend` (PR #472) -- Update tests for macOS and Windows, speedup documentation (PR #484) -- Added Proximal Point algorithm to solve GW problems via a new parameter `solver="PPA"` in `ot.gromov.entropic_gromov_wasserstein` + examples (PR #455) -- Added features `warmstart` and `kwargs` in `ot.gromov.entropic_gromov_wasserstein` to respectively perform warmstart on dual potentials and pass parameters to `ot.sinkhorn` (PR #455) -- Added sinkhorn projection based solvers for FGW `ot.gromov.entropic_fused_gromov_wasserstein` and entropic FGW barycenters + examples (PR #455) -- Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455) -- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455) -- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455) -- Add Entropic Wasserstein Component Analysis (ECWA) in ot.dr (PR #486) -- Added feature Efficient Discrete Multi Marginal Optimal Transport Regularization + examples (PR #454) -#### Closed issues - -- Fix gromov conventions (PR #497) -- Fix change in scipy API for `cdist` (PR #487) -- More permissive check_backend (PR #494) -- Fix circleci-redirector action and codecov (PR #460) -- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457) -- Major documentation cleanup (PR #462, PR #467, PR #475) -- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466) -- Faster Bures-Wasserstein distance with NumPy backend (PR #468) -- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471) -- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (PR #474) -- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472) -- Fix pression error on marginal sums and (Issue #429, PR #496) - -#### New Contributors -* @kachayev made their first contribution in PR #462 -* @liutianlin0121 made their first contribution in PR #459 -* @francois-rozet made their first contribution in PR #468 -* @framunoz made their first contribution in PR #472 -* @SoniaMaz8 made their first contribution in PR #483 -* @tomMoral made their first contribution in PR #494 -* @12hengyu made their first contribution in PR #454 - -## 0.9.0 -*April 2023* - -This new release contains so many new features and bug fixes since 0.8.2 that we -decided to make it a new minor release at 0.9.0. - -The release contains many new features. First we did a major -update of all Gromov-Wasserstein solvers that brings up to 30% gain in -computation time (see PR #431) and allows the GW solvers to work on non symmetric -matrices. It also brings novel solvers for the very -efficient [semi-relaxed GW problem -](https://pythonot.github.io/master/auto_examples/gromov/plot_semirelaxed_fgw.html#sphx-glr-auto-examples-gromov-plot-semirelaxed-fgw-py) -that can be used to find the best re-weighting for one of the distributions. We -also now have fast and differentiable solvers for [Wasserstein on the circle](https://pythonot.github.io/master/auto_examples/plot_compute_wasserstein_circle.html#sphx-glr-auto-examples-plot-compute-wasserstein-circle-py) and -[sliced Wasserstein on the -sphere](https://pythonot.github.io/master/auto_examples/backends/plot_ssw_unif_torch.html#sphx-glr-auto-examples-backends-plot-ssw-unif-torch-py). -We are also very happy to provide new OT barycenter solvers such as the [Free -support Sinkhorn -barycenter](https://pythonot.github.io/master/auto_examples/barycenters/plot_free_support_sinkhorn_barycenter.html#sphx-glr-auto-examples-barycenters-plot-free-support-sinkhorn-barycenter-py) -and the [Generalized Wasserstein -barycenter](https://pythonot.github.io/master/auto_examples/barycenters/plot_generalized_free_support_barycenter.html#sphx-glr-auto-examples-barycenters-plot-generalized-free-support-barycenter-py). -A new differentiable solver for OT across spaces that provides OT plans -between samples and features simultaneously and -called [Co-Optimal -Transport](https://pythonot.github.io/master/auto_examples/others/plot_COOT.html) -has also been implemented. Finally we began working on OT between Gaussian distributions and -now provide differentiable estimation for the Bures-Wasserstein [divergence](https://pythonot.github.io/master/gen_modules/ot.gaussian.html#ot.gaussian.bures_wasserstein_distance) and -[mappings](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html#sphx-glr-auto-examples-domain-adaptation-plot-otda-linear-mapping-py). - -Another important first step toward POT 1.0 is the -implementation of a unified API for OT solvers with introduction of [`ot.solve`](https://pythonot.github.io/master/all.html#ot.solve) -function that can solve (depending on parameters) exact, regularized and -unbalanced OT and return a new -[`OTResult`](https://pythonot.github.io/master/gen_modules/ot.utils.html#ot.utils.OTResult) -object. The idea behind this new API is to facilitate exploring different solvers -with just a change of parameter and get a more unified API for them. We will keep -the old solvers API for power users but it will be the preferred way to solve -problems starting from release 1.0.0. -We provide below some examples of use for the new function and how to -recover different aspects of the solution (OT plan, full loss, linear part of the -loss, dual variables) : -```python -#Solve exact ot -sol = ot.solve(M) - -# get the results -G = sol.plan # OT plan -ot_loss = sol.value # OT value (full loss for regularized and unbalanced) -ot_loss_linear = sol.value_linear # OT value for linear term np.sum(sol.plan*M) -alpha, beta = sol.potentials # dual potentials - -# direct plan and loss computation -G = ot.solve(M).plan -ot_loss = ot.solve(M).value - -# OT exact with marginals a/b -sol2 = ot.solve(M, a, b) - -# regularized and unbalanced OT -sol_rkl = ot.solve(M, a, b, reg=1) # KL regularization -sol_rl2 = ot.solve(M, a, b, reg=1, reg_type='L2') -sol_ul2 = ot.solve(M, a, b, unbalanced=10, unbalanced_type='L2') -sol_rkl_ukl = ot.solve(M, a, b, reg=10, unbalanced=10) # KL + KL - -``` -The function is fully compatible with backends and will be implemented for -different types of distribution support (empirical distributions, grids) and OT -problems (Gromov-Wasserstein) in the new releases. This new API is not yet -presented in the kickstart part of the documentation as there is a small change -that it might change -when implementing new solvers but we encourage users to play with it. - -Finally, in addition to those many new this release fixes 20 issues (some long -standing) and we want to thank all the contributors who made this release so -big. More details below. - - -#### New features -- Added feature to (Fused) Gromov-Wasserstein solvers inherited from `ot.optim` to support relative and absolute loss variations as stopping criterions (PR #431) -- Added feature to (Fused) Gromov-Wasserstein solvers to handle asymmetric matrices (PR #431) -- Added semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #431) -- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434) -- Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434) -- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434) -- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434) - Added Bures Wasserstein distance in `ot.gaussian` (PR ##428) - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) - New API for OT solver using function `ot.solve` (PR #388) -- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449) -- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437) -- Added parameters method in `ot.da.SinkhornTransport` (PR #440) -- `ot.dr` now uses the new Pymanopt API and POT is compatible with current - Pymanopt (PR #443) -- Added CO-Optimal Transport solver + examples (PR #447) -- Remove the redundant `nx.abs()` at the end of `wasserstein_1d()` (PR #448) +- Backend version of `ot.partial` and `ot.smooth` (PR #388) +- Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #437) #### Closed issues @@ -200,11 +69,9 @@ PR #413) - Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls that explicitly specified `stopThr=1e-9` (Issue #421, PR #422). - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425) -- Fixed an issue with the documentation gallery section (PR #444) -- Fixed issues with cuda variables for `line_search_armijo` and `entropic_gromov_wasserstein` (Issue #445, #PR 446) + ## 0.8.2 -*April 2022* This releases introduces several new notable features. The less important but most exiting one being that we now have a logo for the toolbox (color @@ -348,7 +215,7 @@ a [Generative Network (GAN)](https://PythonOT.github.io/auto_examples/backends/plot_wass2_gan_torch.html), for a [sliced Wasserstein gradient flow](https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html) -and [optimizing the Gromov-Wasserstein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite +and [optimizing the Gromov-Wassersein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite slow at the moment, we strongly recommend for Jax users to use the [OTT toolbox](https://github.com/google-research/ott) when possible. As a result of this new feature, @@ -360,7 +227,7 @@ Pointwise Gromov Wasserstein](https://PythonOT.github.io/auto_examples/gromov/plot_gromov.html#compute-gw-with-a-scalable-stochastic-method-with-any-loss-function), Sinkhorn in log space with `method='sinkhorn_log'`, [Projection Robust Wasserstein](https://PythonOT.github.io/gen_modules/ot.dr.html?highlight=robust#ot.dr.projection_robust_wasserstein), -ans [debiased Sinkhorn barycenters](https://PythonOT.github.ioauto_examples/barycenters/plot_debiased_barycenter.html). +ans [deviased Sinkorn barycenters](https://PythonOT.github.ioauto_examples/barycenters/plot_debiased_barycenter.html). This release will also simplify the installation process. We have now a `pyproject.toml` that defines the build dependency and POT should now build even @@ -501,7 +368,7 @@ are coming for the next versions. #### Closed issues -- Add JMLR paper to the readme and Mathieu Blondel to the Acknowledgments (PR +- Add JMLR paper to the readme and Mathieu Blondel to the Acknoledgments (PR #231, #232) - Bug in Unbalanced OT example (Issue #127) - Clean Cython output when calling setup.py clean (Issue #122) @@ -509,7 +376,7 @@ are coming for the next versions. - EMD dimension mismatch (Issue #114, Fixed in PR #116) - 2D barycenter bug for non square images (Issue #124, fixed in PR #132) - Bad value in EMD 1D (Issue #138, fixed in PR #139) -- Log bugs for Gromov-Wasserstein solver (Issue #107, fixed in PR #108) +- Log bugs for Gromov-Wassertein solver (Issue #107, fixed in PR #108) - Weight issues in barycenter function (PR #106) ## 0.6.0 @@ -540,9 +407,9 @@ a solver for [Unbalanced OT barycenters](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_barycenter_1D.ipynb). A new variant of Gromov-Wasserstein divergence called [Fused Gromov-Wasserstein](https://pot.readthedocs.io/en/latest/all.html?highlight=fused_#ot.gromov.fused_gromov_wasserstein) -has been also contributed with examples of use on [structured +has been also contributed with exemples of use on [structured data](https://github.com/rflamary/POT/blob/master/notebooks/plot_fgw.ipynb) and -computing [barycenters of labeled +computing [barycenters of labeld graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb). @@ -603,7 +470,7 @@ and [free support](https://github.com/rflamary/POT/blob/master/notebooks/plot_fr implementation of entropic OT. POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework instead of -the unmaintained cudamat. Note that while we tried to keep changes to the +the unmaintained cudamat. Note that while we tried to keed changes to the minimum, the OTDA classes were deprecated. If you are happy with the cudamat implementation, we recommend you stay with stable release 0.4 for now. @@ -627,7 +494,7 @@ and new POT contributors (you can see the list in the [readme](https://github.co * Stochastic OT in the dual and semi-dual (PR #52 and PR #62) * Free support barycenters (PR #56) * Speed-up Sinkhorn function (PR #57 and PR #58) -* Add convolutional Wasserstein barycenters for 2D images (PR #64) +* Add convolutional Wassersein barycenters for 2D images (PR #64) * Add Greedy Sinkhorn variant (Greenkhorn) (PR #66) * Big ot.gpu update with cupy implementation (instead of un-maintained cudamat) (PR #67) @@ -678,7 +545,7 @@ This release contains a lot of contribution from new contributors. * new notebooks for emd computation and Wasserstein Discriminant Analysis * relocate notebooks * update documentation -* clean_zeros(a,b,M) for removing zeros in sparse distributions +* clean_zeros(a,b,M) for removimg zeros in sparse distributions * GPU implementations for sinkhorn and group lasso regularization @@ -686,7 +553,7 @@ This release contains a lot of contribution from new contributors. *7 Apr 2017* * New dimensionality reduction method (WDA) -* Efficient method emd2 returns only transport (in parallel if several histograms given) +* Efficient method emd2 returns only tarnsport (in paralell if several histograms given) @@ -727,4 +594,4 @@ It provides the following solvers: * Optimal transport for domain adaptation with group lasso regularization * Conditional gradient and Generalized conditional gradient for regularized OT. -Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. \ No newline at end of file +Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 265006d2c..73667b324 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -19,7 +19,8 @@ from .utils import list_to_array, get_parameter_pair -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', + reg_type="entropy", warmstart=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem @@ -39,7 +40,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -67,8 +68,17 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -100,9 +110,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. _references-sinkhorn-unbalanced: References @@ -134,21 +143,21 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: @@ -156,8 +165,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', - numItermax=1000, stopThr=1e-6, verbose=False, - log=False, **kwargs): + reg_type="entropy", warmstart=None, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -175,7 +184,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -203,8 +212,17 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_reg_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_reg_scaling', see those function for specific parameterss + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -226,12 +244,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', -------- >>> import ot + >>> import numpy as np >>> a=[.5, .10] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) - array([0.31912866]) - + >>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8) + 0.31912858 .. _references-sinkhorn-unbalanced2: References @@ -258,34 +276,60 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] ` """ - b = list_to_array(b) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + if len(b.shape) < 2: - b = b[:, None] + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) - if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) - elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) else: - raise ValueError('Unknown method %s.' % method) + if method.lower() == 'sinkhorn': + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", + warmstart=None, numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the OT plan @@ -304,7 +348,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -330,6 +374,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -361,9 +414,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. _references-sinkhorn-knopp-unbalanced: References @@ -404,15 +456,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, 1), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - a = a.reshape(dim_a, 1) + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, 1), type_as=M) + v = nx.ones((dim_b, n_hists), type_as=M) + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) + v = nx.ones(dim_b, type_as=M) else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - K = nx.exp(M / (-reg)) + if reg_type == "kl": + K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :] + elif reg_type == "entropy": + K = nx.exp(-M / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -474,9 +532,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, return u[:, None] * K * v[None, :] -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, - **kwargs): +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", + warmstart=None, tau=1e5, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -496,7 +555,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -523,6 +582,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). tau : float threshold for max value in u or v for log scaling numItermax : int, optional @@ -555,9 +623,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. _references-sinkhorn-stabilized-unbalanced: References @@ -597,16 +664,24 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - a = a.reshape(dim_a, 1) + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) + v = nx.ones((dim_b, n_hists), type_as=M) + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) + v = nx.ones(dim_b, type_as=M) else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - # print(reg) - K = nx.exp(-M / reg) + if reg_type == "kl": + log_ab = nx.log(a + 1e-16).reshape(-1)[:, None] + nx.log(b + 1e-16).reshape(-1)[None, :] + M0 = M - reg * log_ab + else: + M0 = M + + K = nx.exp(-M0 / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -641,7 +716,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 else: alpha = alpha + reg * nx.log(nx.max(u)) beta = beta + reg * nx.log(nx.max(v)) - K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + K = nx.exp((alpha[:, None] + beta[None, :] - M0) / reg) v = nx.ones(v.shape, type_as=v) Kv = nx.dot(K, v) @@ -687,7 +762,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 nx.log(M + 1e-100)[:, :, None] + logu[:, None, :] + logv[None, :, :] - - M[:, :, None] / reg, + - M0[:, :, None] / reg, axis=(0, 1) ) res = nx.exp(res) @@ -697,7 +772,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 return res else: # return OT matrix - ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg) + ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M0 / reg) if log: return ot_matrix, log else: @@ -1074,7 +1149,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, raise ValueError("Unknown method '%s'." % method) -def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, +def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1084,7 +1159,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) s.t. \gamma \geq 0 @@ -1094,6 +1169,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a maximization- @@ -1113,8 +1189,11 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, then the same reg_m is applied to both marginal relaxations. If reg_m is an array, it must have the same backend as input arrays (a, b, M). reg : float, optional (default = 0) - Entropy regularization term >= 0. + Regularization term >= 0. By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1172,36 +1251,33 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, if len(b) == 0: b = nx.ones(dim_b, type_as=M) / dim_b - if G0 is None: - G = a[:, None] * b[None, :] - else: - G = G0 + G = a[:, None] * b[None, :] if G0 is None else G0 + c = a[:, None] * b[None, :] if c is None else c reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: log = {'err': [], 'G': []} - if div == 'kl': - sum_r = reg + reg_m1 + reg_m2 - r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) - elif div == 'l2': - K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * a[:, None] * b[None, :] - M - K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) - else: + if div not in ["kl", "l2"]: warnings.warn("The div parameter should be either equal to 'kl' or \ 'l2': it has been set to 'kl'.") div = 'kl' + + if div == 'kl': sum_r = reg + reg_m1 + reg_m2 r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) + K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r) + elif div == 'l2': + K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M + K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) for i in range(numItermax): Gprev = G if div == 'kl': - G = K * G**(r1 + r2) / (nx.sum(G, 1, keepdims=True)**r1 * nx.sum(G, 0, keepdims=True)**r2 + 1e-16) + Gd = (nx.sum(G, 1, keepdims=True)**r1) * (nx.sum(G, 0, keepdims=True)**r2) + 1e-16 + G = K * G**(r1 + r2) / Gd elif div == 'l2': Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \ reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16 @@ -1223,7 +1299,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, return G -def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, +def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1233,7 +1309,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) s.t. \gamma \geq 0 @@ -1243,6 +1319,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a maximization- @@ -1264,6 +1341,9 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, reg : float, optional (default = 0) Entropy regularization term >= 0. By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = mathbf{a} mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1307,7 +1387,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, ot.lp.emd2 : Unregularized OT loss ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - _, log_mm = mm_unbalanced(a, b, M, reg_m, reg=reg, div=div, G0=G0, + _, log_mm = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=True) @@ -1317,7 +1397,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, return log_mm['cost'] -def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): +def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): """ return the loss function (scipy.optimize compatible) for regularized unbalanced OT @@ -1326,25 +1406,25 @@ def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='k m, n = M.shape def kl(p, q): - return np.sum(p * np.log(p / q + 1e-16)) - p.sum() + q.sum() + return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q) def reg_l2(G): - return np.sum((G - a[:, None] * b[None, :])**2) / 2 + return np.sum((G - c)**2) / 2 def grad_l2(G): - return G - a[:, None] * b[None, :] + return G - c def reg_kl(G): - return kl(G, a[:, None] * b[None, :]) + return kl(G, c) def grad_kl(G): - return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + return np.log(G / c + 1e-16) def reg_entropy(G): - return np.sum(G * np.log(G + 1e-16)) + return np.sum(G * np.log(G + 1e-16)) - np.sum(G) def grad_entropy(G): - return np.log(G + 1e-16) + 1 + return np.log(G + 1e-16) if reg_div == 'kl': reg_fun = reg_kl @@ -1392,7 +1472,7 @@ def _func(G): return _func -def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, +def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. @@ -1400,7 +1480,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -1412,6 +1492,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize @@ -1426,6 +1507,9 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, loss matrix reg: float regularization term >=0 + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 Marginal relaxation term >= 0, but cannot be infinity. If reg_m is a scalar or an indexable object of length 1, @@ -1433,7 +1517,8 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, If reg_m is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + Can take three values: 'entropy' (negative entropy), or + 'kl' (Kullback-Leibler) or 'l2' (quadratic). regm_div: string, optional Divergence to quantify the difference between the marginals. Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1482,19 +1567,15 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, M, a, b = list_to_array(M, a, b) nx = get_backend(M, a, b) - M0 = M + # convert to numpy a, b, M = nx.to_numpy(a, b, M) + G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) + c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) reg_m1, reg_m2 = get_parameter_pair(reg_m) - - if G0 is not None: - G0 = nx.to_numpy(G0) - else: - G0 = np.zeros(M.shape) - - _func = _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div, regm_div) + _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 272794cb8..7007e336b 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -14,8 +14,8 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_convergence(nx, method): +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -25,29 +25,32 @@ def test_unbalanced_convergence(nx, method): # make dists unbalanced b = ot.utils.unif(n) * 1.5 - M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + epsilon = 1. reg_m = 1. - a, b, M = nx.from_numpy(a, b, M) - - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method=method, - log=True, - verbose=True) + G, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=True, verbose=True + ) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, epsilon, reg_m, method=method, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, verbose=True )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) logb = nx.log(b + 1e-16) loga = nx.log(a + 1e-16) - logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) - + if reg_type == "entropy": + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + elif reg_type == "kl": + log_ab = loga[:, None] + logb[None, :] + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon + log_ab.T, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon + log_ab, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) @@ -65,15 +68,109 @@ def test_unbalanced_convergence(nx, method): a, b = nx.from_numpy(a_np, b_np) G = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, + method=method, reg_type=reg_type, verbose=True ) G_np = ot.unbalanced.sinkhorn_unbalanced( - a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, + method=method, reg_type=reg_type, verbose=True ) np.testing.assert_allclose(G_np, nx.to_numpy(G)) -@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +def test_unbalanced_warmstart(nx, method, reg_type): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + + epsilon = 1. + reg_m = 1. + + G0, log0 = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=None, log=True, verbose=True + ) + loss0 = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=None, verbose=True + ) + + dim_a, dim_b = M.shape + warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) + G, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart, log=True, verbose=True + ) + loss = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart, verbose=True + ) + + _, log_emd = ot.lp.emd(a, b, M, log=True) + warmstart1 = (log_emd["u"], log_emd["v"]) + G1, log1 = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True + ) + loss1 = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart1, verbose=True + ) + + np.testing.assert_allclose( + nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log0["logu"]), nx.to_numpy(log1["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log0["logv"]), nx.to_numpy(log1["logv"]), atol=1e-05) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) + + +@pytest.mark.parametrize("method,reg_type, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"], [True, False])) +def test_sinkhorn_unbalanced2(nx, method, reg_type, log): + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + + epsilon = 1. + reg_m = 1. + + loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=False, verbose=True + )) + + res = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=log, verbose=True + ) + loss0 = res[0] if log else res + + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + + +@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [1, float("inf")])) def test_unbalanced_relaxation_parameters(nx, method, reg_m): # test generalized sinkhorn for unbalanced OT n = 100 @@ -117,7 +214,7 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m): nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -135,11 +232,10 @@ def test_unbalanced_multiple_inputs(nx, method): a, b, M = nx.from_numpy(a, b, M) - loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method=method, - log=True, - verbose=True) + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, method=method, + log=True, verbose=True) + # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) @@ -394,6 +490,31 @@ def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div): np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_reference_measure(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + M = ot.dist(xs, xt) + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + c = a[:, None] * b[None, :] + + G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=None, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + G0, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=c, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + + @pytest.mark.parametrize("div", ["kl", "l2"]) def test_mm_convergence(nx, div): n = 100 @@ -483,6 +604,41 @@ def test_mm_relaxation_parameters(nx, div): np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_reference_measure(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + c = a[:, None] * b[None, :] + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=False, log=True) + loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=True) + loss_0 = nx.to_numpy(loss_0) + + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, verbose=True) + loss_1 = nx.to_numpy(loss_1) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + def test_mm_wrong_divergence(nx): n = 100 From 1ece2d87c43de87607b6804162205c2351a08d33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Nov 2023 13:25:41 +0100 Subject: [PATCH 4/4] fix conflict with RELEASES.md (#561) --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index 97834cdfe..223eb0116 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -13,6 +13,8 @@ + Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543) + Upgraded unbalanced OT solvers for more flexibility (PR #539) + Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544) ++ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556) ++ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) + Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551) #### Closed issues