add pytorch-geometric style batching for the online backend #84

Open · wants to merge 1 commit into base: main
43 changes: 34 additions & 9 deletions geomloss/samples_loss.py
@@ -206,13 +206,24 @@ def __init__(
self.potentials = potentials
self.verbose = verbose

def forward(self, *args):
def forward(self, *args, ptr_x=None, ptr_y=None):
"""Computes the loss between sampled measures.

Documentation and examples: Soon!
Until then, please check the tutorials :-)"""
Until then, please check the tutorials :-)

l_x, α, x, l_y, β, y = self.process_args(*args)
Keyword-only arguments:

ptr_x (LongTensor, default=None): If **backend** is ``"online"``, specifies the batch structure pyg-style,
i.e. the pointer tensor that marks the start of each sample in the flattened **x** tensor.
If **None**, the routine falls back to the standard behaviour and treats **x** as a single sample
(or as a dense batch along the first dimension).
ptr_y (LongTensor, default=None): Same convention as **ptr_x**, applied to the flattened **y** tensor.

"""

l_x, α, x, l_y, β, y = self.process_args(*args, ptr_x=ptr_x, ptr_y=ptr_y)
B, N, M, D, l_x, α, l_y, β = self.check_shapes(l_x, α, x, l_y, β, y)

backend = (
@@ -227,7 +238,9 @@ def forward(self, *args):
)

elif backend == "auto":
if M * N <= 5000**2:
if ptr_x is not None or ptr_y is not None:
backend = "online"
elif M * N <= 5000**2:
backend = (
"tensorized" # Fast backend, with a quadratic memory footprint
)
@@ -260,6 +273,10 @@ def forward(self, *args):
]: # tensorized and online routines work on batched tensors
α, x, β, y = α.unsqueeze(0), x.unsqueeze(0), β.unsqueeze(0), y.unsqueeze(0)

if ptr_x is not None:
kwargs = {"ptr_x": ptr_x, "ptr_y": ptr_y}
else: kwargs = {}

# Run --------------------------------------------------------------------------------
values = routines[self.loss][backend](
α,
@@ -280,6 +297,7 @@ def forward(self, *args):
labels_x=l_x,
labels_y=l_y,
verbose=self.verbose,
**kwargs
)

# Make sure that the output has the correct shape ------------------------------------
@@ -288,7 +306,8 @@ def forward(self, *args):
): # Return some dual potentials (= test functions) sampled on the input measures
F, G = values
return F.view_as(α), G.view_as(β)

elif ptr_x is not None:
return values # The user expects a "batch vector" of distances
else: # Return a scalar cost value
if backend in ["multiscale"]: # KeOps backends return a single scalar value
if B == 0:
@@ -304,23 +323,29 @@ def forward(self, *args):
else:
return values # The user expects a "batch vector" of distances

def process_args(self, *args):
def process_args(self, *args, ptr_x=None, ptr_y=None):
if len(args) == 6:
return args
if len(args) == 4:
α, x, β, y = args
return None, α, x, None, β, y
elif len(args) == 2:
x, y = args
α = self.generate_weights(x)
β = self.generate_weights(y)
α = self.generate_weights(x, ptr_x)
β = self.generate_weights(y, ptr_y)
return None, α, x, None, β, y
else:
raise ValueError(
"A SamplesLoss accepts two (x, y), four (α, x, β, y) or six (l_x, α, x, l_y, β, y) arguments."
)

def generate_weights(self, x):
def generate_weights(self, x, ptr=None):
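# pyg-style batching: ptr marks the start of each sample in the flattened tensor,
# so each of the ptr[i+1] - ptr[i] points of sample i gets the uniform weight 1 / (ptr[i+1] - ptr[i]).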
if ptr is not None:
with torch.no_grad():
batch = ptr[1:] - ptr[:-1]
weights = 1 / batch.type_as(x)
weights = torch.repeat_interleave(weights, batch)
return weights
if x.dim() == 2: #
N = x.shape[0]
return torch.ones(N).type_as(x) / N
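For reference, a minimal usage sketch of the new `ptr_x` / `ptr_y` keyword arguments, based on the docstring and `process_args` changes above. The point counts and dimension are made up for illustration; with pointers supplied, the `elif ptr_x is not None` branch returns one value per sample.

```python
import torch
from geomloss import SamplesLoss

# Two ragged samples per side, concatenated along dim 0 (pyg-style):
x = torch.randn(100 + 250, 3)        # sample 0: 100 points, sample 1: 250 points
ptr_x = torch.tensor([0, 100, 350])  # pointers into the flattened x tensor
y = torch.randn(80 + 300, 3)
ptr_y = torch.tensor([0, 80, 380])

loss = SamplesLoss("sinkhorn", p=2, blur=0.05)  # backend="auto" is routed to "online" when ptr_* is given
values = loss(x, y, ptr_x=ptr_x, ptr_y=ptr_y)   # expected: a "batch vector" with one distance per sample
```
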
72 changes: 51 additions & 21 deletions geomloss/sinkhorn_divergence.py
@@ -170,7 +170,11 @@ def scaling_parameters(x, y, p, blur, reach, diameter, scaling):


def sinkhorn_cost(
eps, rho, a, b, f_aa, g_bb, g_ab, f_ba, batch=False, debias=True, potentials=False
eps, rho, a, b, f_aa, g_bb, g_ab, f_ba, batch=False, debias=True, potentials=False,
batch_ranges_xy=None,
batch_ranges_yx=None,
batch_ranges_xx=None,
batch_ranges_yy=None,
):
r"""Returns the required information (cost, etc.) from a set of dual potentials.

@@ -189,7 +193,8 @@ def sinkhorn_cost(
Defaults to True.
potentials (bool, optional): Shall we return the dual vectors instead of the cost value?
Defaults to False.

batch_ranges_xy, batch_ranges_yx, batch_ranges_xx, batch_ranges_yy:
See the documentation of the `sinkhorn_loop` function.
Returns:
Tensor or pair of Tensors: if `potentials` is True, we return a pair
of (..., N), (..., M) Tensors that encode the optimal dual vectors,
@@ -211,8 +216,8 @@
): # UNBIASED Sinkhorn divergence, S_eps(a,b) = OT_eps(a,b) - .5*OT_eps(a,a) - .5*OT_eps(b,b)
if rho is None: # Balanced case:
# See Eq. (3.209) in Jean Feydy's PhD thesis.
return scal(a, f_ba - f_aa, batch=batch) + scal(
b, g_ab - g_bb, batch=batch
return scal(a, f_ba - f_aa, batch=batch, ranges=batch_ranges_xx) + scal(
b, g_ab - g_bb, batch=batch, ranges=batch_ranges_yy
)
else:
# Unbalanced case:
@@ -222,21 +227,25 @@
return scal(
a,
UnbalancedWeight(eps, rho)(
(-f_aa / rho).exp() - (-f_ba / rho).exp()
(-f_aa / rho).exp() - (-f_ba / rho).exp(),
),
batch=batch,
ranges=batch_ranges_xx,
) + scal(
b,
UnbalancedWeight(eps, rho)(
(-g_bb / rho).exp() - (-g_ab / rho).exp()
(-g_bb / rho).exp() - (-g_ab / rho).exp(),

),
batch=batch,
ranges=batch_ranges_yy,
)

else: # Classic, BIASED entropized Optimal Transport OT_eps(a,b)
if rho is None: # Balanced case:
# See Eq. (3.207) in Jean Feydy's PhD thesis.
return scal(a, f_ba, batch=batch) + scal(b, g_ab, batch=batch)
return scal(a, f_ba, batch=batch, ranges=batch_ranges_xx) \
+ scal(b, g_ab, batch=batch, ranges=batch_ranges_yy)
else:
# Unbalanced case:
# See Proposition 12 (Dual formulas for the Sinkhorn costs)
@@ -245,9 +254,11 @@
# N.B.: Even if this quantity is never used in practice,
# we may want to re-check this computation...
return scal(
a, UnbalancedWeight(eps, rho)(1 - (-f_ba / rho).exp()), batch=batch
a, UnbalancedWeight(eps, rho)(1 - (-f_ba / rho).exp()),
batch=batch, ranges=batch_ranges_xx
) + scal(
b, UnbalancedWeight(eps, rho)(1 - (-g_ab / rho).exp()), batch=batch
b, UnbalancedWeight(eps, rho)(1 - (-g_ab / rho).exp()),
batch=batch, ranges=batch_ranges_yy
)


@@ -273,6 +284,11 @@ def sinkhorn_loop(
extrapolate=None,
debias=True,
last_extrapolation=True,
batch_ranges_xy=None,
batch_ranges_yx=None,
batch_ranges_xx=None,
batch_ranges_yy=None,

):
r"""Implements the (possibly multiscale) symmetric Sinkhorn loop,
with the epsilon-scaling (annealing) heuristic.
@@ -387,6 +403,14 @@ def sinkhorn_loop(
to backpropagate through the full Sinkhorn loop.
Defaults to True.

batch_ranges_xy, batch_ranges_yx, batch_ranges_xx, batch_ranges_yy (list of Tensors, optional):
Ranges that encode which points of the input measures may be matched together;
useful for pytorch-geometric style batching.
Given in the KeOps ``ranges`` format, see
https://www.kernel-operations.io/keops/python/api/numpy/Genred_numpy.html#pykeops.numpy.Genred.__call__
Only compatible with the "single-scale" mode for now!
Defaults to None.

Returns:
4-uple of Tensors: The four optimal dual potentials
`(f_aa, g_bb, g_ab, f_ba)` that are respectively
@@ -415,6 +439,12 @@ def sinkhorn_loop(
# Cost "matrices" C(x_i, x_j) and C(y_i, y_j):
if debias: # Only used for the "a <-> a" and "b <-> b" problems.
C_xxs, C_yys = [C_xxs], [C_yys]
if batch_ranges_xx or batch_ranges_xy or batch_ranges_yx or batch_ranges_yy:
assert len(a_logs) == 1, "Batching is only compatible with single-scale mode for now."
kwargs_xy = {"ranges": batch_ranges_xy} if batch_ranges_xy is not None else {}
kwargs_yx = {"ranges": batch_ranges_yx} if batch_ranges_yx is not None else {}
kwargs_xx = {"ranges": batch_ranges_xx} if batch_ranges_xx is not None else {}
kwargs_yy = {"ranges": batch_ranges_yy} if batch_ranges_yy is not None else {}

# N.B.: We don't let users backprop through the Sinkhorn iterations
# and branch instead on an explicit formula "at convergence"
@@ -459,11 +489,11 @@ def sinkhorn_loop(
# a convolution with the cost function (i.e. the limit for eps=+infty).
# The algorithm was originally written with this convolution
# - but in this implementation, we use "softmin" for the sake of simplicity.
g_ab = damping * softmin(eps, C_yx, a_log) # a -> b
f_ba = damping * softmin(eps, C_xy, b_log) # b -> a
g_ab = damping * softmin(eps, C_yx, a_log, **kwargs_yx) # a -> b
f_ba = damping * softmin(eps, C_xy, b_log, **kwargs_xy) # b -> a
if debias:
f_aa = damping * softmin(eps, C_xx, a_log) # a -> a
g_bb = damping * softmin(eps, C_yy, b_log) # b -> b
f_aa = damping * softmin(eps, C_xx, a_log, **kwargs_xx) # a -> a
g_bb = damping * softmin(eps, C_yy, b_log, **kwargs_yy) # b -> b

# Lines 4-5: eps-scaling descent ---------------------------------------------------
for i, eps in enumerate(eps_list): # See Fig. 3.25-26 in Jean Feydy's PhD thesis.
@@ -478,15 +508,15 @@ def sinkhorn_loop(
# (for "f-tilde", "g-tilde") using the standard
# Sinkhorn formulas, and update both dual vectors
# simultaneously.
ft_ba = damping * softmin(eps, C_xy, b_log + g_ab / eps) # b -> a
gt_ab = damping * softmin(eps, C_yx, a_log + f_ba / eps) # a -> b
ft_ba = damping * softmin(eps, C_xy, b_log + g_ab / eps, **kwargs_xy) # b -> a
gt_ab = damping * softmin(eps, C_yx, a_log + f_ba / eps, **kwargs_yx) # a -> b

# See Fig. 3.21 in Jean Feydy's PhD thesis to see the importance
# of debiasing when the target "blur" or "eps**(1/p)" value is larger
# than the average distance between samples x_i, y_j and their neighbours.
if debias:
ft_aa = damping * softmin(eps, C_xx, a_log + f_aa / eps) # a -> a
gt_bb = damping * softmin(eps, C_yy, b_log + g_bb / eps) # b -> b
ft_aa = damping * softmin(eps, C_xx, a_log + f_aa / eps, **kwargs_xx) # a -> a
gt_bb = damping * softmin(eps, C_yy, b_log + g_bb / eps, **kwargs_yy) # b -> b

# Symmetrized updates - see Fig. 3.24.b in Jean Feydy's PhD thesis:
f_ba, g_ab = 0.5 * (f_ba + ft_ba), 0.5 * (g_ab + gt_ab) # OT(a,b) wrt. a, b
@@ -615,13 +645,13 @@ def sinkhorn_loop(
if last_extrapolation:
# The cross-updates should be done in parallel!
f_ba, g_ab = (
damping * softmin(eps, C_xy, (b_log + g_ab / eps).detach()),
damping * softmin(eps, C_yx, (a_log + f_ba / eps).detach()),
damping * softmin(eps, C_xy, (b_log + g_ab / eps).detach(), **kwargs_xy),
damping * softmin(eps, C_yx, (a_log + f_ba / eps).detach(), **kwargs_yx),
)

if debias:
f_aa = damping * softmin(eps, C_xx, (a_log + f_aa / eps).detach())
g_bb = damping * softmin(eps, C_yy, (b_log + g_bb / eps).detach())
f_aa = damping * softmin(eps, C_xx, (a_log + f_aa / eps).detach(), **kwargs_xx)
g_bb = damping * softmin(eps, C_yy, (b_log + g_bb / eps).detach(), **kwargs_yy)

if debias:
return f_aa, g_bb, g_ab, f_ba
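The `batch_ranges_*` arguments above are documented in the KeOps `ranges` format, but this diff does not show how they are derived from the `ptr_x` / `ptr_y` pointers (that conversion presumably lives in the online routines, which are not part of this excerpt). As a rough sketch of the format only, a block-diagonal `ranges` 6-tuple could be assembled along these lines; `ptr_to_ranges` is a hypothetical helper, not code from this PR:

```python
import torch

def ptr_to_ranges(ptr_i, ptr_j):
    """Block-diagonal KeOps `ranges` 6-tuple built from two pyg-style pointer tensors."""
    B = ptr_i.numel() - 1
    ranges_i = torch.stack([ptr_i[:-1], ptr_i[1:]], dim=1).int()  # (B, 2) row blocks [start, end)
    ranges_j = torch.stack([ptr_j[:-1], ptr_j[1:]], dim=1).int()  # (B, 2) column blocks [start, end)
    slices = torch.arange(1, B + 1, dtype=torch.int32)            # one reduction block per sample
    # KeOps expects (ranges_i, slices_i, redranges_j, ranges_j, slices_j, redranges_i):
    return ranges_i, slices, ranges_j, ranges_j, slices, ranges_i

ptr_x = torch.tensor([0, 100, 350])  # hypothetical pointers for a batch of two samples
ptr_y = torch.tensor([0, 80, 380])
batch_ranges_xy = ptr_to_ranges(ptr_x, ptr_y)  # x -> y reductions
batch_ranges_yx = ptr_to_ranges(ptr_y, ptr_x)  # y -> x reductions
batch_ranges_xx = ptr_to_ranges(ptr_x, ptr_x)  # debiasing terms OT(a, a)
batch_ranges_yy = ptr_to_ranges(ptr_y, ptr_y)  # debiasing terms OT(b, b)
```

With this block-diagonal structure, sample k on the i side is only reduced against sample k on the j side, which is exactly the sparsity pattern that pyg-style batching needs.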