Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Update _gw.py #637

Merged
merged 2 commits into from
Jun 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric

Where :

- :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- :math:`\mathbf{p}`: distribution in the source space
- :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
- :math:`\mathbf{C_1}`: Metric cost matrix in the source space.
- :math:`\mathbf{C_2}`: Metric cost matrix in the target space.
- :math:`\mathbf{p}`: Distribution in the source space.
- :math:`\mathbf{q}`: Distribution in the target space.
- `L`: Loss function to account for the misfit between the similarity matrices.

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends. But the algorithm uses the C++ CPU backend
Expand All @@ -62,39 +62,39 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
Parameters
----------
C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
Metric cost matrix in the source space.
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
Metric cost matrix in the target space.
p : array-like, shape (ns,), optional
Distribution in the source space.
If let to its default value None, uniform distribution is taken.
q : array-like, shape (nt,), optional
Distribution in the target space.
If let to its default value None, uniform distribution is taken.
loss_fun : str, optional
loss function used for the solver either 'square_loss' or 'kl_loss'
Loss function used for the solver either 'square_loss' or 'kl_loss'.
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
If let to its default None value, a symmetry test will be conducted.
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
verbose : bool, optional
Print information along iterations
Print information along iterations.
log : bool, optional
record log if True
Record log if True.
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
If True, the step of the line-search is found via an armijo search. Else closed form is used.
If there are convergence issues, use False.
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
If None, the initial transport plan of the solver is pq^T.
Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
max_iter : int, optional
Max number of iterations
Max number of iterations.
tol_rel : float, optional
Stop threshold on relative error (>0)
Stop threshold on relative error (>0).
tol_abs : float, optional
Stop threshold on absolute error (>0)
Stop threshold on absolute error (>0).
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
Parameters can be directly passed to the ot.optim.cg solver.

Returns
-------
Expand Down Expand Up @@ -175,7 +175,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):

if not nx.is_floating_point(C10):
warnings.warn(
"Input structure matrix consists of integer. The transport plan will be "
"Input structure matrix consists of integers. The transport plan will be "
"casted accordingly, possibly resulting in a loss of precision. "
"If this behaviour is unwanted, please make sure your input "
"structure matrix consists of floating point elements.",
Expand Down
Loading