Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
cedricvincentcuaz committed Oct 2, 2024
1 parent b1e9434 commit 527506a
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions test/gromov/test_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def test_partial_gromov_wasserstein(nx):
warn=True, verbose=True)

resb_ = nx.to_numpy(resb)
np.testing.assert_allclose(res, 0, atol=1e-1, rtol=1e-1)
np.testing.assert_allclose(res, resb_, atol=1e-15)
np.testing.assert_allclose(res, 0, rtol=1e-4)
np.testing.assert_allclose(res, resb_, rtol=1e-4)
assert np.all(res.sum(1) <= p) # cf convergence wasserstein
assert np.all(res.sum(0) <= q) # cf convergence wasserstein
np.testing.assert_allclose(
Expand All @@ -110,7 +110,7 @@ def test_partial_gromov_wasserstein(nx):
C1b, C1subb, p=pb, q=psubb, m=m, log=True)

resb_ = nx.to_numpy(resb)
np.testing.assert_allclose(res, resb_, atol=1e-15)
np.testing.assert_allclose(res, resb_, rtol=1e-4)
assert np.all(res.sum(1) <= p) # cf convergence wasserstein
assert np.all(res.sum(0) <= psub) # cf convergence wasserstein
np.testing.assert_allclose(
Expand All @@ -123,16 +123,16 @@ def test_partial_gromov_wasserstein(nx):
res0b, log0b = ot.gromov.partial_gromov_wasserstein(
C1b, C2b, pb, qb, m=None, log=True)
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
np.testing.assert_allclose(G, res0, atol=1e-04)
np.testing.assert_allclose(res0b, res0, atol=1e-04)
np.testing.assert_allclose(G, res0, rtol=1e-4)
np.testing.assert_allclose(res0b, res0, rtol=1e-4)

# tests for pGW2
for loss_fun in ['square_loss', 'kl_loss']:
w0, log0 = ot.gromov.partial_gromov_wasserstein2(
C1, C2, p=None, q=q, m=m, loss_fun=loss_fun, log=True)
w0_val = ot.gromov.partial_gromov_wasserstein2(
C1b, C2b, p=pb, q=None, m=m, loss_fun=loss_fun, log=False)
np.testing.assert_allclose(w0, w0_val, rtol=1e-8)
np.testing.assert_allclose(w0, w0_val, rtol=1e-4)

# tests integers
C1_int = C1.astype(int)
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_partial_partial_gromov_linesearch(nx):
G0b, deltaGb, cost_G0b, df_G0b, fC1, fC2, hC1, hC2, 0., 1.,
alpha_min=0., alpha_max=1.)

np.testing.assert_allclose(alpha, 1., atol=1e-2)
np.testing.assert_allclose(alpha, 1., rtol=1e-4)


@pytest.skip_backend("jax", reason="test very slow with jax backend")
Expand Down Expand Up @@ -252,12 +252,12 @@ def test_entropic_partial_gromov_wasserstein(nx):
symmetric=False, verbose=True)

resb_ = nx.to_numpy(resb)
np.testing.assert_allclose(res, 0, atol=1e-1, rtol=1e-1)
np.testing.assert_allclose(res, resb_, atol=1e-15)
np.testing.assert_allclose(res, 0, rtol=1e-4)
np.testing.assert_allclose(res, resb_, rtol=1e-4)
assert np.all(res.sum(1) <= p) # cf convergence wasserstein
assert np.all(res.sum(0) <= q) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(res), m, atol=1e-15)
np.sum(res), m, rtol=1e-4)

# tests with m is None
res = ot.gromov.entropic_partial_gromov_wasserstein(
Expand All @@ -272,7 +272,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
np.testing.assert_allclose(res, 0, atol=1e-1, rtol=1e-1)
np.testing.assert_allclose(res, resb_, atol=1e-7)
np.testing.assert_allclose(
np.sum(res), 1., atol=1e-8)
np.sum(res), 1., rtol=1e-4)

# tests with different number of samples across spaces
m = 0.5
Expand All @@ -283,11 +283,11 @@ def test_entropic_partial_gromov_wasserstein(nx):
C1b, C1subb, p=pb, q=psubb, reg=1e4, m=m, log=True)

resb_ = nx.to_numpy(resb)
np.testing.assert_allclose(res, resb_, atol=1e-15)
np.testing.assert_allclose(res, resb_, rtol=1e-4)
assert np.all(res.sum(1) <= p) # cf convergence wasserstein
assert np.all(res.sum(0) <= psub) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(res), m, atol=1e-15)
np.sum(res), m, rtol=1e-4)

# tests for pGW2
for loss_fun in ['square_loss', 'kl_loss']:
Expand Down

0 comments on commit 527506a

Please sign in to comment.