Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: check for NaNs in emd loss matrix #623

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
8 changes: 8 additions & 0 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c

.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.

.. note:: An error will be raised if the loss matrix :math:`\mathbf{M}` contains NaNs.

Uses the algorithm proposed in :ref:`[1] <references-emd>`.

Parameters
Expand Down Expand Up @@ -324,6 +326,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c
# convert to numpy
M, a, b = nx.to_numpy(M, a, b)

if np.isnan(M).any():
raise ValueError('The loss matrix should not contain NaN values.')

# ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
Expand Down Expand Up @@ -502,6 +507,9 @@ def emd2(a, b, M, processes=1,
# convert to numpy
M, a, b = nx.to_numpy(M, a, b)

if np.isnan(M).any():
raise ValueError('The loss matrix should not contain NaN values.')

a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64, order='C')
Expand Down
15 changes: 15 additions & 0 deletions test/gromov/test_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,3 +910,18 @@ def test_fgw_barycenter(nx):

np.testing.assert_allclose(C, Cb, atol=1e-06)
np.testing.assert_allclose(X, Xb, atol=1e-06)


# Related to issue 469
def test_gromov2_nan_in_target_cost():
    """Check that gromov_wasserstein2 raises ValueError for a NaN target cost."""
    # GIVEN - two 2-point uniform distributions, an all-zero source cost and
    # an all-ones target cost whose [0, 0] entry is NaN
    p = np.full(2, 0.5)
    q = np.full(2, 0.5)
    C1 = np.zeros((2, 2))
    C2 = np.array([[np.nan, 1.0], [1.0, 1.0]])

    # WHEN / THEN - the call must fail with the NaN error message
    expected_msg = 'The loss matrix should not contain NaN values.'
    with pytest.raises(ValueError, match=expected_msg):
        ot.gromov_wasserstein2(C1, C2, p, q)
5 changes: 5 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ def test_emd_empty():
np.testing.assert_allclose(w, 0)


def test_emd_nan_in_loss_matrix():
    """Check that ot.emd raises ValueError when the loss matrix holds a NaN."""
    nan_loss = [np.nan]
    expected_msg = 'The loss matrix should not contain NaN values.'
    with pytest.raises(ValueError, match=expected_msg):
        ot.emd([], [], nan_loss)


def test_emd2_multi():
n = 500 # nb bins

Expand Down
Loading