-
Notifications
You must be signed in to change notification settings - Fork 502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: check for NaNs in emd loss matrix #623
base: master
Are you sure you want to change the base?
Changes from 3 commits
fb5bb0c
fc53a26
9942d1e
0bf7dfa
1336ed2
be8a5ea
727d01d
79d00b9
b75d07c
1713360
55e7ac8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -237,6 +237,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c | |
.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. | ||
.. note:: An error will be raided if the loss matrix :math:`\mathbf{M}` contains NaNs. | ||
Uses the algorithm proposed in :ref:`[1] <references-emd>`. | ||
Parameters | ||
|
@@ -302,6 +304,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c | |
ot.optim.cg : General regularized OT | ||
""" | ||
|
||
if np.isnan(M).any(): | ||
raise ValueError('The loss matrix should not contain NaN values.') | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Failing early here ensures that we do not segfault in the accelerated I did not look too deep into the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if the graph is disconnected then the parts that are not used should have an infinite value (which is ha,ndled by the C++ solver). i'm OK with not handling naNs. |
||
a, b, M = list_to_array(a, b, M) | ||
nx = get_backend(M, a, b) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A problem here is that you are using numpy on arrays that might not be numpy (see backend function below). You should do the test later in the function on the OT loss marix that hhas been converted to numpy to avoid backend errors.