diff --git a/geomloss/samples_loss.py b/geomloss/samples_loss.py index 776c5dd..1cc40c0 100644 --- a/geomloss/samples_loss.py +++ b/geomloss/samples_loss.py @@ -206,13 +206,24 @@ def __init__( self.potentials = potentials self.verbose = verbose - def forward(self, *args): + def forward(self, *args, ptr_x=None, ptr_y=None): """Computes the loss between sampled measures. Documentation and examples: Soon! - Until then, please check the tutorials :-)""" + Until then, please check the tutorials :-) - l_x, α, x, l_y, β, y = self.process_args(*args) + keyword only: + + ptr_x (LongTensor, default=None): If **backend** is ``"online"``, specifies the batches pyg-style, + i.e. the pointer tensor that indicates the start of each new sample in the flattened **x** tensor. + If **None**, the routine will assume that the samples are concatenated along the first dimension. + ptr_y (LongTensor, default=None): If **backend** is ``"online"``, specifies the batches pyg-style, + i.e. the pointer tensor that indicates the start of each new sample in the flattened **y** tensor. + If **None**, the routine will assume that the samples are concatenated along the first dimension. + + """ + + l_x, α, x, l_y, β, y = self.process_args(*args, ptr_x=ptr_x, ptr_y=ptr_y) B, N, M, D, l_x, α, l_y, β = self.check_shapes(l_x, α, x, l_y, β, y) backend = ( @@ -227,7 +238,9 @@ def forward(self, *args): ) elif backend == "auto": - if M * N <= 5000**2: + if ptr_x is not None or ptr_y is not None: + backend = "online" + elif M * N <= 5000**2: backend = ( "tensorized" # Fast backend, with a quadratic memory footprint ) @@ -260,6 +273,10 @@ def forward(self, *args): ]: # tensorized and online routines work on batched tensors α, x, β, y = α.unsqueeze(0), x.unsqueeze(0), β.unsqueeze(0), y.unsqueeze(0) + if ptr_x is not None: + kwargs = {"ptr_x": ptr_x, "ptr_y": ptr_y} + else: kwargs = {} + # Run -------------------------------------------------------------------------------- values = routines[self.loss][backend]( α, @@ -280,6 +297,7 @@ def forward(self, *args): labels_x=l_x, labels_y=l_y, verbose=self.verbose, + **kwargs ) # Make sure that the output has the correct shape ------------------------------------ @@ -288,7 +306,8 @@ def forward(self, *args): ): # Return some dual potentials (= test functions) sampled on the input measures F, G = values return F.view_as(α), G.view_as(β) - + elif ptr_x is not None: + return values # The user expects a "batch vector" of distances else: # Return a scalar cost value if backend in ["multiscale"]: # KeOps backends return a single scalar value if B == 0: @@ -304,7 +323,7 @@ def forward(self, *args): else: return values # The user expects a "batch vector" of distances - def process_args(self, *args): + def process_args(self, *args, ptr_x=None, ptr_y=None): if len(args) == 6: return args if len(args) == 4: @@ -312,15 +331,21 @@ def process_args(self, *args): return None, α, x, None, β, y elif len(args) == 2: x, y = args - α = self.generate_weights(x) - β = self.generate_weights(y) + α = self.generate_weights(x, ptr_x) + β = self.generate_weights(y, ptr_y) return None, α, x, None, β, y else: raise ValueError( "A SamplesLoss accepts two (x, y), four (α, x, β, y) or six (l_x, α, x, l_y, β, y) arguments." 
) - def generate_weights(self, x): + def generate_weights(self, x, ptr = None): + if ptr is not None: + with torch.no_grad(): + batch = ptr[1:] - ptr[:-1] + weights = 1 / batch.type_as(x) + weights = torch.repeat_interleave(weights, batch) + return weights if x.dim() == 2: # N = x.shape[0] return torch.ones(N).type_as(x) / N diff --git a/geomloss/sinkhorn_divergence.py b/geomloss/sinkhorn_divergence.py index ac6974a..1ab0c82 100644 --- a/geomloss/sinkhorn_divergence.py +++ b/geomloss/sinkhorn_divergence.py @@ -170,7 +170,11 @@ def scaling_parameters(x, y, p, blur, reach, diameter, scaling): def sinkhorn_cost( - eps, rho, a, b, f_aa, g_bb, g_ab, f_ba, batch=False, debias=True, potentials=False + eps, rho, a, b, f_aa, g_bb, g_ab, f_ba, batch=False, debias=True, potentials=False, + batch_ranges_xy=None, + batch_ranges_yx=None, + batch_ranges_xx=None, + batch_ranges_yy=None, ): r"""Returns the required information (cost, etc.) from a set of dual potentials. @@ -189,7 +193,8 @@ def sinkhorn_cost( Defaults to True. potentials (bool, optional): Shall we return the dual vectors instead of the cost value? Defaults to False. - + batch_ranges_xy, batch_ranges_yx, batch_ranges_xx, batch_ranges_yy: + see the documentation of the `sinkhorn_loop` function. Returns: Tensor or pair of Tensors: if `potentials` is True, we return a pair of (..., N), (..., M) Tensors that encode the optimal dual vectors, @@ -211,8 +216,8 @@ def sinkhorn_cost( ): # UNBIASED Sinkhorn divergence, S_eps(a,b) = OT_eps(a,b) - .5*OT_eps(a,a) - .5*OT_eps(b,b) if rho is None: # Balanced case: # See Eq. (3.209) in Jean Feydy's PhD thesis. - return scal(a, f_ba - f_aa, batch=batch) + scal( - b, g_ab - g_bb, batch=batch + return scal(a, f_ba - f_aa, batch=batch, ranges=batch_ranges_xx) + scal( + b, g_ab - g_bb, batch=batch, ranges=batch_ranges_yy ) else: # Unbalanced case: @@ -222,21 +227,25 @@ def sinkhorn_cost( return scal( a, UnbalancedWeight(eps, rho)( - (-f_aa / rho).exp() - (-f_ba / rho).exp() + (-f_aa / rho).exp() - (-f_ba / rho).exp(), ), batch=batch, + ranges=batch_ranges_xx, ) + scal( b, UnbalancedWeight(eps, rho)( - (-g_bb / rho).exp() - (-g_ab / rho).exp() + (-g_bb / rho).exp() - (-g_ab / rho).exp(), + ), batch=batch, + ranges=batch_ranges_yy, ) else: # Classic, BIASED entropized Optimal Transport OT_eps(a,b) if rho is None: # Balanced case: # See Eq. (3.207) in Jean Feydy's PhD thesis. - return scal(a, f_ba, batch=batch) + scal(b, g_ab, batch=batch) + return scal(a, f_ba, batch=batch, ranges=batch_ranges_xx) \ + + scal(b, g_ab, batch=batch, ranges=batch_ranges_yy) else: # Unbalanced case: # See Proposition 12 (Dual formulas for the Sinkhorn costs) @@ -245,9 +254,11 @@ def sinkhorn_cost( # N.B.: Even if this quantity is never used in practice, # we may want to re-check this computation... return scal( - a, UnbalancedWeight(eps, rho)(1 - (-f_ba / rho).exp()), batch=batch + a, UnbalancedWeight(eps, rho)(1 - (-f_ba / rho).exp()), + batch=batch, ranges=batch_ranges_xx ) + scal( - b, UnbalancedWeight(eps, rho)(1 - (-g_ab / rho).exp()), batch=batch + b, UnbalancedWeight(eps, rho)(1 - (-g_ab / rho).exp()), + batch=batch, ranges=batch_ranges_yy ) @@ -273,6 +284,11 @@ def sinkhorn_loop( extrapolate=None, debias=True, last_extrapolation=True, + batch_ranges_xy=None, + batch_ranges_yx=None, + batch_ranges_xx=None, + batch_ranges_yy=None, + ): r"""Implements the (possibly multiscale) symmetric Sinkhorn loop, with the epsilon-scaling (annealing) heuristic. 
@@ -387,6 +403,14 @@ def sinkhorn_loop( to backpropagate trough the full Sinkhorn loop. Defaults to True. + batch_ranges_xy (list of Tensors, optional), batch_ranges_yx (list of Tensors, optional), + batch_ranges_xx (list of Tensors, optional), batch_ranges_yy (list of Tensors, optional): + List of ranges that encode which points can be matched together + in the input measures. Useful for pytorch-geometric style batching. + Defaults to None. + Given in KeOps format, see https://www.kernel-operations.io/keops/python/api/numpy/Genred_numpy.html#pykeops.numpy.Genred.__call__ + Only compatible with the "single-scale" mode for now! + Returns: 4-uple of Tensors: The four optimal dual potentials `(f_aa, g_bb, g_ab, f_ba)` that are respectively @@ -415,6 +439,12 @@ def sinkhorn_loop( # Cost "matrices" C(x_i, x_j) and C(y_i, y_j): if debias: # Only used for the "a <-> a" and "b <-> b" problems. C_xxs, C_yys = [C_xxs], [C_yys] + if batch_ranges_xx or batch_ranges_xy or batch_ranges_yx or batch_ranges_yy: + assert len(a_logs) == 1, "Batching is only compatible with single-scale mode for now." + kwargs_xy = {"ranges": batch_ranges_xy} if batch_ranges_xy is not None else {} + kwargs_yx = {"ranges": batch_ranges_yx} if batch_ranges_yx is not None else {} + kwargs_xx = {"ranges": batch_ranges_xx} if batch_ranges_xx is not None else {} + kwargs_yy = {"ranges": batch_ranges_yy} if batch_ranges_yy is not None else {} # N.B.: We don't let users backprop through the Sinkhorn iterations # and branch instead on an explicit formula "at convergence" @@ -459,11 +489,11 @@ def sinkhorn_loop( # a convolution with the cost function (i.e. the limit for eps=+infty). # The algorithm was originally written with this convolution # - but in this implementation, we use "softmin" for the sake of simplicity. - g_ab = damping * softmin(eps, C_yx, a_log) # a -> b - f_ba = damping * softmin(eps, C_xy, b_log) # b -> a + g_ab = damping * softmin(eps, C_yx, a_log, **kwargs_yx) # a -> b + f_ba = damping * softmin(eps, C_xy, b_log, **kwargs_xy) # b -> a if debias: - f_aa = damping * softmin(eps, C_xx, a_log) # a -> a - g_bb = damping * softmin(eps, C_yy, b_log) # a -> a + f_aa = damping * softmin(eps, C_xx, a_log, **kwargs_xx) # a -> a + g_bb = damping * softmin(eps, C_yy, b_log, **kwargs_yy) # a -> a # Lines 4-5: eps-scaling descent --------------------------------------------------- for i, eps in enumerate(eps_list): # See Fig. 3.25-26 in Jean Feydy's PhD thesis. @@ -478,15 +508,15 @@ def sinkhorn_loop( # (for "f-tilde", "g-tilde") using the standard # Sinkhorn formulas, and update both dual vectors # simultaneously. - ft_ba = damping * softmin(eps, C_xy, b_log + g_ab / eps) # b -> a - gt_ab = damping * softmin(eps, C_yx, a_log + f_ba / eps) # a -> b + ft_ba = damping * softmin(eps, C_xy, b_log + g_ab / eps, **kwargs_xy) # b -> a + gt_ab = damping * softmin(eps, C_yx, a_log + f_ba / eps, **kwargs_yx) # a -> b # See Fig. 3.21 in Jean Feydy's PhD thesis to see the importance # of debiasing when the target "blur" or "eps**(1/p)" value is larger # than the average distance between samples x_i, y_j and their neighbours. if debias: - ft_aa = damping * softmin(eps, C_xx, a_log + f_aa / eps) # a -> a - gt_bb = damping * softmin(eps, C_yy, b_log + g_bb / eps) # b -> b + ft_aa = damping * softmin(eps, C_xx, a_log + f_aa / eps, **kwargs_xx) # a -> a + gt_bb = damping * softmin(eps, C_yy, b_log + g_bb / eps, **kwargs_yy) # b -> b # Symmetrized updates - see Fig. 
3.24.b in Jean Feydy's PhD thesis: f_ba, g_ab = 0.5 * (f_ba + ft_ba), 0.5 * (g_ab + gt_ab) # OT(a,b) wrt. a, b @@ -615,13 +645,13 @@ def sinkhorn_loop( if last_extrapolation: # The cross-updates should be done in parallel! f_ba, g_ab = ( - damping * softmin(eps, C_xy, (b_log + g_ab / eps).detach()), - damping * softmin(eps, C_yx, (a_log + f_ba / eps).detach()), + damping * softmin(eps, C_xy, (b_log + g_ab / eps).detach(), **kwargs_xy), + damping * softmin(eps, C_yx, (a_log + f_ba / eps).detach(), **kwargs_yx), ) if debias: - f_aa = damping * softmin(eps, C_xx, (a_log + f_aa / eps).detach()) - g_bb = damping * softmin(eps, C_yy, (b_log + g_bb / eps).detach()) + f_aa = damping * softmin(eps, C_xx, (a_log + f_aa / eps).detach(), **kwargs_xx) + g_bb = damping * softmin(eps, C_yy, (b_log + g_bb / eps).detach(), **kwargs_yy) if debias: return f_aa, g_bb, g_ab, f_ba diff --git a/geomloss/sinkhorn_samples.py b/geomloss/sinkhorn_samples.py index 0bc3016..8042a7b 100644 --- a/geomloss/sinkhorn_samples.py +++ b/geomloss/sinkhorn_samples.py @@ -14,7 +14,7 @@ except: keops_available = False -from .utils import scal, squared_distances, distances +from .utils import scal, squared_distances, distances, ranges_from_ptr from .sinkhorn_divergence import epsilon_schedule, scaling_parameters from .sinkhorn_divergence import dampening, log_weights, sinkhorn_cost, sinkhorn_loop @@ -290,6 +290,71 @@ def softmin_online_lazytensor(eps, C_xy, h_y, p=2): return -eps * smin +def softmin_online_ranges(eps, C_xy, h_y, p=2,*, ranges): + r"""Soft-C-transform, implemented using symbolic KeOps LazyTensors. + + This routine implements the (soft-)C-transform + between dual vectors, which is the core computation for + Auction- and Sinkhorn-like optimal transport solvers. + + If `eps` is a float number, `C_xy = (x, y)` is a pair of (batched) + point clouds, encoded as (B, N, D) and (B, M, D) Tensors + and `h_y` encodes a dual potential :math:`h_j` that is supported by the points + :math:`y_j`'s, then `softmin_tensorized(eps, C_xy, h_y)` returns a dual potential + `f` for ":math:`f_i`", supported by the :math:`x_i`'s, that is equal to: + + .. math:: + f_i \gets - \varepsilon \log \sum_{j=1}^{\text{M}} \exp + \big[ h_j - \|x_i - y_j\|^p / p \varepsilon \big]~. + + For more detail, see e.g. Section 3.3 and Eq. (3.186) in Jean Feydy's PhD thesis. + + Args: + eps (float, positive): Temperature :math:`\varepsilon` for the Gibbs kernel + :math:`K_{i,j} = \exp(- \|x_i - y_j\|^p / p \varepsilon)`. + + C_xy (pair of (B, N, D), (B, M, D) Tensors): Point clouds :math:`x_i` + and :math:`y_j`, with a batch dimension. + + h_y ((B, M) Tensor): Vector of logarithmic "dual" values, with a batch dimension. + Most often, this vector will be computed as `h_y = b_log + g_j / eps`, + where `b_log` is a vector of log-weights :math:`\log(\beta_j)` + for the :math:`y_j`'s and :math:`g_j` is a dual vector + in the Sinkhorn algorithm, so that: + + .. math:: + f_i \gets - \varepsilon \log \sum_{j=1}^{\text{M}} \beta_j + \exp \tfrac{1}{\varepsilon} \big[ g_j - \|x_i - y_j\|^p / p \big]~. + + Returns: + (B, N) Tensor: Dual potential `f` of values :math:`f_i`, supported + by the points :math:`x_i`. 
+ """ + x, y = C_xy # Retrieve our point clouds + B = x.shape[0] # Batch dimension + assert B == 1 + + # Encoding as batched KeOps LazyTensors: + x_i = LazyTensor(x.squeeze(0)[:, None, :]) # (N, 1, D) + y_j = LazyTensor(y.squeeze(0)[None, :, :]) # (1, M, D) + h_j = LazyTensor(h_y.squeeze(0)[None, :, None]) # (1, M, 1) + + # Cost matrix: + if p == 2: # Halved, squared Euclidean distance + C_ij = ((x_i - y_j) ** 2).sum(-1) / 2 # (B, N, M, 1) + + elif p == 1: # Simple Euclidean distance + C_ij = ((x_i - y_j) ** 2).sum(-1).sqrt() # (B, N, M, 1) + + else: + raise NotImplementedError() + + pre_reduction = (h_j - C_ij * torch.Tensor([1 / eps]).type_as(x)) + + # KeOps log-sum-exp reduction over the "M" dimension: + smin = pre_reduction.logsumexp(1, ranges=ranges).view(B, -1) + + return -eps * smin def lse_lazytensor(p, D, batchdims=(1,)): """This implementation is currently disabled.""" @@ -360,6 +425,8 @@ def sinkhorn_online( cost=None, debias=True, potentials=False, + ptr_x=None, + ptr_y=None, **kwargs, ): B, N, D = x.shape @@ -372,7 +439,8 @@ def sinkhorn_online( else: my_lse = lse_lazytensor(p, D, batchdims=(B,)) softmin = partial(softmin_online, log_conv=my_lse) - + elif ptr_x is not None or ptr_y is not None: + softmin = partial(softmin_online_ranges, p=p) else: if B > 1: raise ValueError( @@ -393,6 +461,18 @@ def sinkhorn_online( C_xx, C_yy = ((x, x.detach()), (y, y.detach())) if debias else (None, None) C_xy, C_yx = ((x, y.detach()), (y, x.detach())) + if ptr_x is not None and ptr_y is not None: + ranges_x, slices_x = ranges_from_ptr(ptr_x) + ranges_y, slices_y = ranges_from_ptr(ptr_y) + ranges_xy = (ranges_x, slices_x, ranges_y, ranges_y, slices_y, ranges_x) + ranges_yx = (ranges_y, slices_y, ranges_x, ranges_x, slices_x, ranges_y) + ranges_xx = (ranges_x, slices_x, ranges_x, ranges_x, slices_x, ranges_x) + ranges_yy = (ranges_y, slices_y, ranges_y, ranges_y, slices_y, ranges_y) + elif ptr_x is None and ptr_y is None : + ranges_xy = ranges_yx = ranges_xx = ranges_yy = None + else: + raise ValueError("ptr_x and ptr_y should be both None or both not None") + diameter, eps, eps_list, rho = scaling_parameters( x, y, p, blur, reach, diameter, scaling ) @@ -408,6 +488,10 @@ def sinkhorn_online( eps_list, rho, debias=debias, + batch_ranges_xx=ranges_xx, + batch_ranges_yy=ranges_yy, + batch_ranges_xy=ranges_xy, + batch_ranges_yx=ranges_yx, ) return sinkhorn_cost( @@ -422,6 +506,10 @@ def sinkhorn_online( batch=True, debias=debias, potentials=potentials, + batch_ranges_xx=ranges_xx, + batch_ranges_yy=ranges_yy, + batch_ranges_xy=ranges_xy, + batch_ranges_yx=ranges_yx, ) diff --git a/geomloss/utils.py b/geomloss/utils.py index b3b9e8c..7e8fb75 100644 --- a/geomloss/utils.py +++ b/geomloss/utils.py @@ -10,8 +10,22 @@ except: keops_available = False - -def scal(a, f, batch=False): +def segment_sum(x, batch): + """uses only torch primitives, avoids needing to import torch_scatter""" + if x.dim() == 2: + x = x.squeeze(0) + result = torch.zeros(batch.shape[0], *x.shape[1:], device=x.device) + index = torch.arange(batch.shape[0], device=x.device) + index = index.repeat_interleave(batch) + result.index_add_(0, index, x) + return result + +def scal(a, f, batch=False, ranges=None): + if ranges is not None: + ranges_a, *_ = ranges + batch_a = ranges_a[:, 1] - ranges_a[:, 0] + prod = a * f + return segment_sum(prod, batch_a) if batch: B = a.shape[0] return (a.reshape(B, -1) * f.reshape(B, -1)).sum(1) @@ -278,3 +292,16 @@ def softmin(a_log): ) # Act on dim 2 return -eps * h_y + + +######################## +# 
For batching
+def ranges_from_ptr(ptr_x):
+    """Converts a pyg-style pointer tensor into KeOps block-sparse (ranges, slices) tensors."""
+    B, = ptr_x.shape  # number of samples + 1
+    ranges_x = torch.stack([ptr_x[:-1], ptr_x[1:]], dim=-1)  # (B - 1, 2) rows of [start, end)
+    slices = torch.arange(1, B, device=ptr_x.device)  # one slice index per range row: [1, ..., B - 1]
+    return ranges_x, slices
+
+
diff --git a/test_scrip.py b/test_scrip.py
index ac3d3e9..18e73ce 100644
--- a/test_scrip.py
+++ b/test_scrip.py
@@ -1,7 +1,7 @@
 import torch
 from geomloss import SamplesLoss
 
-device = "cuda"
+device = "cpu"
 
 tdtype = torch.float
 
@@ -40,3 +40,29 @@
 # print(a, b)
 # print(c, d)
 print(torch.norm(A - B))
+
+
+xs = [torch.randn((8, 2), dtype=torch.float, device=device), torch.randn((7, 2), dtype=torch.float, device=device), torch.randn((6, 2), dtype=torch.float, device=device),]
+ys = [torch.randn((3, 2), dtype=torch.float, device=device), torch.randn((4, 2), dtype=torch.float, device=device), torch.randn((5, 2), dtype=torch.float, device=device),]
+
+L_online = SamplesLoss(
+    "sinkhorn",
+    p=p,
+    blur=0.5,
+    potentials=potential,
+    debias=False,
+    backend="online",
+)
+
+distances = torch.as_tensor([L_online(x, y) for x, y in zip(xs, ys)], device=device)
+
+ptr_x = torch.tensor([0, 8, 15, 21], device=device)
+ptr_y = torch.tensor([0, 3, 7, 12], device=device)
+xs_batched = torch.cat(xs, dim=0)
+ys_batched = torch.cat(ys, dim=0)
+
+distances_batched = L_online(xs_batched, ys_batched, ptr_x=ptr_x, ptr_y=ptr_y)
+
+# Small differences are expected: the batched call estimates the eps-annealing diameter from the concatenated clouds (vs. per-pair diameters above), and the Sinkhorn loop only runs a finite schedule, so the two results agree only up to convergence/floating-point error.
+print(torch.norm(distances - distances_batched))
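
For reference, here is a minimal, pure-PyTorch sketch of what the ptr-based batching above encodes; it is not part of the patch and does not call KeOps. It walks through the per-sample weights built by `generate_weights`, the `(ranges, slices)` descriptors returned by `ranges_from_ptr`, the block-diagonal interaction pattern that the 6-tuple of ranges describes (realised here with an explicit dense `mask`, whereas KeOps applies it lazily), and the per-sample reduction performed by `segment_sum` inside `scal`. The dense `mask`, the dual vector `f` and `per_cloud` are illustrative stand-ins only.

# Minimal, self-contained sketch of the batching scheme (pure PyTorch, no KeOps).
import torch

ptr_x = torch.tensor([0, 8, 15, 21])   # three point clouds of sizes 8, 7, 6, concatenated along dim 0
ptr_y = torch.tensor([0, 3, 7, 12])    # three point clouds of sizes 3, 4, 5

# Per-sample uniform weights, as built by SamplesLoss.generate_weights(x, ptr):
sizes_x = ptr_x[1:] - ptr_x[:-1]                                 # tensor([8, 7, 6])
sizes_y = ptr_y[1:] - ptr_y[:-1]                                 # tensor([3, 4, 5])
weights_x = (1.0 / sizes_x.float()).repeat_interleave(sizes_x)   # (21,), each cloud sums to 1
weights_y = (1.0 / sizes_y.float()).repeat_interleave(sizes_y)   # (12,)

# ranges_from_ptr(ptr): one [start, end) row per cloud, plus one slice index per row.
# (On the KeOps side these are documented as int32 IntTensors, so a cast may be needed.)
ranges_x = torch.stack([ptr_x[:-1], ptr_x[1:]], dim=-1)          # (3, 2)
ranges_y = torch.stack([ptr_y[:-1], ptr_y[1:]], dim=-1)          # (3, 2)
slices = torch.arange(1, ranges_x.shape[0] + 1)                  # tensor([1, 2, 3])

# The 6-tuple (ranges_x, slices, ranges_y, ranges_y, slices, ranges_x) describes a
# block-diagonal interaction pattern: the k-th block of x rows only meets the
# k-th block of y columns. Dense equivalent over the full (21, 12) cost matrix:
mask = torch.zeros(int(ptr_x[-1]), int(ptr_y[-1]), dtype=torch.bool)
for (x_start, x_end), (y_start, y_end) in zip(ranges_x.tolist(), ranges_y.tolist()):
    mask[x_start:x_end, y_start:y_end] = True

# Dense stand-in for softmin_online_ranges: a soft-C-transform restricted to the blocks.
x = torch.randn(21, 2)
y = torch.randn(12, 2)
eps = 0.5
C = torch.cdist(x, y) ** 2 / 2                                   # halved squared distances (p = 2)
h_y = weights_y.log()                                            # log-weights, as in log_weights(beta)
f = -eps * torch.where(mask, h_y - C / eps, torch.tensor(-float("inf"))).logsumexp(dim=1)  # (21,)

# segment_sum / scal(a, f, ranges=...): one value per cloud instead of a single scalar.
cloud_id = torch.arange(sizes_x.shape[0]).repeat_interleave(sizes_x)  # cloud index of every x point
per_cloud = torch.zeros(sizes_x.shape[0]).index_add_(0, cloud_id, weights_x * f)
print(per_cloud)                                                 # three (dummy) values, one per cloud

Because the masked entries never enter the log-sum-exp, each concatenated cloud is solved independently of the others; this is the behaviour the `ranges` argument requests from the KeOps reductions, without ever materialising the full cost matrix.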