Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Gaussian initialization for sinkhorn #555

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil

[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017.

[60] Thornton, James, and Marco Cuturi. [Rethinking initialization of the Sinkhorn algorithm](https://arxiv.org/pdf/2206.07630.pdf). International Conference on Artificial Intelligence and Statistics. PMLR, 2023.
3 changes: 2 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
binary_search_circle, wasserstein_circle,
semidiscrete_wasserstein2_unif_circle)
from .bregman import sinkhorn, sinkhorn2, barycenter
from .bregman import (sinkhorn, sinkhorn2, barycenter, empirical_sinkhorn, empirical_sinkhorn2, empirical_sinkhorn_divergence)
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
Expand All @@ -61,6 +61,7 @@
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn_divergence',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
Expand Down
37 changes: 31 additions & 6 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ot.utils import dist, list_to_array, unif

from .backend import get_backend
from .gaussian import dual_gaussian_init


def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
Expand Down Expand Up @@ -541,6 +542,7 @@
log['niter'] = ii
log['u'] = u
log['v'] = v
log['warmstart'] = (nx.log(u), nx.log(v))

if n_hists: # return only loss
res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
Expand Down Expand Up @@ -697,6 +699,7 @@
'log_v': nx.stack(lst_v, 1), }
log['u'] = nx.exp(log['log_u'])
log['v'] = nx.exp(log['log_v'])
log['warmstart'] = (log['log_u'], log['log_v'])
return res, log
else:
return res
Expand Down Expand Up @@ -2999,15 +3002,23 @@
if b is None:
b = nx.from_numpy(unif(nt), type_as=X_s)

if warmstart is None:
f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
elif warmstart == 'gaussian':
# init only g since f is the first updated
f = dual_gaussian_init(X_s, X_t, a[:, None], b[:, None])
g = dual_gaussian_init(X_t, X_s, b[:, None], a[:, None])

Check warning on line 3010 in ot/bregman.py

View check run for this annotation

Codecov / codecov/patch

ot/bregman.py#L3009-L3010

Added lines #L3009 - L3010 were not covered by tests
elif (isinstance(warmstart, tuple) or isinstance(warmstart, list)) and len(warmstart) == 2:
f, g = warmstart
else:
raise ValueError(

Check warning on line 3014 in ot/bregman.py

View check run for this annotation

Codecov / codecov/patch

ot/bregman.py#L3014

Added line #L3014 was not covered by tests
"warmstart must be None, 'gaussian' or a tuple of two arrays")

if isLazy:
if log:
dict_log = {"err": []}

log_a, log_b = nx.log(a), nx.log(b)
if warmstart is None:
f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
else:
f, g = warmstart

if isinstance(batchSize, int):
bs, bt = batchSize, batchSize
Expand Down Expand Up @@ -3075,6 +3086,7 @@
if log:
dict_log["u"] = f
dict_log["v"] = g
dict_log["warmstart"] = (f, g)
return (f, g, dict_log)
else:
return (f, g)
Expand All @@ -3083,11 +3095,11 @@
M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
verbose=verbose, log=True, warmstart=warmstart, **kwargs)
verbose=verbose, log=True, warmstart=(f, g), **kwargs)
return pi, log
else:
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
verbose=verbose, log=False, warmstart=warmstart, **kwargs)
verbose=verbose, log=False, warmstart=(f, g), **kwargs)
return pi


Expand Down Expand Up @@ -3201,6 +3213,19 @@
if b is None:
b = nx.from_numpy(unif(nt), type_as=X_s)

if warmstart is None:
warmstart = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
elif warmstart == 'gaussian':
# init only g since f is the first updated
f = dual_gaussian_init(X_s, X_t, a[:, None], b[:, None])
g = dual_gaussian_init(X_t, X_s, b[:, None], a[:, None])
warmstart = (f, g)
elif (isinstance(warmstart, tuple) or isinstance(warmstart, list)) and len(warmstart) == 2:
warmstart = warmstart
else:
raise ValueError(

Check warning on line 3226 in ot/bregman.py

View check run for this annotation

Codecov / codecov/patch

ot/bregman.py#L3226

Added line #L3226 was not covered by tests
"warmstart must be None, 'gaussian' or a tuple of two arrays")

if isLazy:
if log:
f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
Expand Down
48 changes: 48 additions & 0 deletions ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,3 +645,51 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None,
return A, b, log
else:
return A, b


def dual_gaussian_init(xs, xt, ws=None, wt=None, reg=1e-6):
    r"""Return the source dual potential Gaussian initialization.

    This function returns the dual potential Gaussian initialization that can
    be used to initialize the Sinkhorn algorithm. This initialization is based
    on the Monge mapping between the source and target distributions seen as
    two Gaussian distributions [60].

    Parameters
    ----------
    xs : array-like (ns,ds)
        samples in the source domain
    xt : array-like (nt,dt)
        samples in the target domain
    ws : array-like (ns,1), optional
        weights for the source samples (uniform weights ``1/ns`` if None)
    wt : array-like (nt,1), optional
        weights for the target samples (uniform weights ``1/nt`` if None)
    reg : float, optional
        regularization added to the diagonals of covariances (>0), forwarded
        to :func:`empirical_bures_wasserstein_mapping`

    Returns
    -------
    f : array-like (ns,)
        source dual potential, one value per source sample

    Notes
    -----
    The target mean is multiplied elementwise with ``xs``, so the expression
    presumably expects ``ds == dt`` (to be confirmed against callers).

    References
    ----------
    .. [60] Thornton, James, and Marco Cuturi. "Rethinking initialization of the
        sinkhorn algorithm." International Conference on Artificial Intelligence
        and Statistics. PMLR, 2023.
    """

    nx = get_backend(xs, xt)

    # default to uniform weights, stored as column vectors so that ``w.T``
    # below is a (1, n) row usable in a matrix product
    if ws is None:
        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

    if wt is None:
        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

    # weighted empirical means, each of shape (1, d)
    mu_s = nx.dot(ws.T, xs) / nx.sum(ws)
    mu_t = nx.dot(wt.T, xt) / nx.sum(wt)

    # only the linear part of the mapping is used; the offset is discarded
    A, _ = empirical_bures_wasserstein_mapping(xs, xt, ws=ws, wt=wt, reg=reg)

    # centered source samples
    xsc = xs - mu_s

    # compute the dual potential (see appendix D in [60]); summing over the
    # feature axis gives one potential value per source sample
    f = nx.sum(xs**2 - nx.dot(xsc, A) * xsc - mu_t * xs, 1)

    return f
8 changes: 8 additions & 0 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,9 @@ def test_empirical_sinkhorn(nx):
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))

loss_emp_sinkhorn_gausss_warmstart = nx.to_numpy(
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, warmstart='gaussian'))

# check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
Expand All @@ -1055,6 +1058,7 @@ def test_empirical_sinkhorn(nx):
np.testing.assert_allclose(
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
np.testing.assert_allclose(loss_emp_sinkhorn_gausss_warmstart, loss_sinkhorn, atol=1e-05)


def test_lazy_empirical_sinkhorn(nx):
Expand Down Expand Up @@ -1095,6 +1099,9 @@ def test_lazy_empirical_sinkhorn(nx):
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))

loss_emp_sinkhorn_gausss_warmstart = nx.to_numpy(
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, warmstart='gaussian', isLazy=True))

# check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
Expand All @@ -1109,6 +1116,7 @@ def test_lazy_empirical_sinkhorn(nx):
np.testing.assert_allclose(
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
np.testing.assert_allclose(loss_emp_sinkhorn_gausss_warmstart, loss_sinkhorn, atol=1e-05)


def test_empirical_sinkhorn_divergence(nx):
Expand Down
19 changes: 19 additions & 0 deletions test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,22 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target):

if d_target >= 2:
np.testing.assert_allclose(Cs, Ctt)


def test_gaussian_init(nx):
    """Check that default weights match explicitly passed uniform weights."""
    n_source, n_target = 50, 50

    # labels returned by the generator are not needed for this check
    Xs, _ = make_data_classif('3gauss', n_source)
    Xt, _ = make_data_classif('3gauss2', n_target)

    # explicit uniform weights as column vectors
    w_source = np.ones((n_source, 1)) / n_source
    w_target = np.ones((n_target, 1)) / n_target

    Xsb, Xtb, w_source_b, w_target_b = nx.from_numpy(
        Xs, Xt, w_source, w_target)

    potential_default = ot.gaussian.dual_gaussian_init(Xsb, Xtb)
    potential_explicit = ot.gaussian.dual_gaussian_init(
        Xsb, Xtb, w_source_b, w_target_b)

    np.testing.assert_allclose(
        nx.to_numpy(potential_default), nx.to_numpy(potential_explicit))
Loading