diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 752c5d2d7..cec98de27 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -237,6 +237,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + .. note:: An error will be raided if the loss matrix :math:`\mathbf{M}` contains NaNs. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -324,6 +326,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c # convert to numpy M, a, b = nx.to_numpy(M, a, b) + if np.isnan(M).any(): + raise ValueError('The loss matrix should not contain NaN values.') + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) @@ -502,6 +507,9 @@ def emd2(a, b, M, processes=1, # convert to numpy M, a, b = nx.to_numpy(M, a, b) + if np.isnan(M).any(): + raise ValueError('The loss matrix should not contain NaN values.') + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64, order='C') diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py index 5b858f307..d8026352e 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -910,3 +910,18 @@ def test_fgw_barycenter(nx): np.testing.assert_allclose(C, Cb, atol=1e-06) np.testing.assert_allclose(X, Xb, atol=1e-06) + + +# Related to issue 469 +def test_gromov2_nan_in_target_cost(): + # GIVEN - a target cost matrix with a NaN value + source_cost = np.zeros((2, 2)) + target_cost = np.ones((2, 2)) + source_distribution = np.array([0.5, 0.5]) + target_distribution = np.array([0.5, 0.5]) + + target_cost[0, 0] = np.nan + + # WHEN - we call + with pytest.raises(ValueError, match='The loss matrix should not contain NaN values.'): + ot.gromov_wasserstein2(source_cost, target_cost, source_distribution, target_distribution) diff --git a/test/test_ot.py b/test/test_ot.py index a90321d5f..6096061cb 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -245,6 +245,11 @@ def test_emd_empty(): np.testing.assert_allclose(w, 0) +def test_emd_nan_in_loss_matrix(): + with pytest.raises(ValueError, match='The loss matrix should not contain NaN values.'): + ot.emd([], [], [np.nan]) + + def test_emd2_multi(): n = 500 # nb bins