Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712804470
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Jan 7, 2025
1 parent a21244b commit 47e98cf
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions swirl_dynamics/projects/debiasing/optimal_transport/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
[1] Cuturi, Marco. "Sinkhorn distances: Lightspeed computation of optimal
transport." Advances in neural information processing systems 26 (2013).
[2] Pooladian, Aram-Alexandre, and Niles-Weed, Jonathan. "Entropic estimation of
optimal transport maps." arXiv preprint arXiv:2109.12004 (2021).
"""

from typing import Callable, NamedTuple
Expand Down Expand Up @@ -261,3 +264,103 @@ def _log_gibbs_kernel(self, u: Array, v: Array, cost_matrix: Array) -> Array:
kernel = -cost_matrix + u[:, None] + v[None, :]
kernel /= self.epsilon
return kernel

def transport_fn(
self, potential: Array, y: Array, weights: Array | None = None
) -> Callable[[Array], Array]:
r"""Transport functions using the formulation in the proposition 2 of [2].
We use the fact that the transport function can be written as:
T(x) = x - 0.5 * \nabla(f_{\epsilon}(x)),
where f_{\epsilon}(x) is the potential computed using the Eq. 9 in [2].
Args:
potential: The potential g (or f) given by the Sinkhorn algorithm.
y: Collection of point of set B, associated with the potential.
weights: Quadrature weights (or marginal densities) for the y points.
Returns:
A function that computes the transport function at a given x.
"""

# Computes the potential of set A.
f_eps = lambda x: self._potential_fn(x, potential, y, weights)

# Here we assume that the cost is the Euclidean distance. In comparison with
# [2] we don't have a 1/2 factor in the definition of the distance, so we
# need to divide by 2.
return jax.vmap(
lambda x: x - 0.5 * jax.grad(f_eps)(x), in_axes=0, out_axes=0
)

def _potential_fn(
self,
x: Array,
potential: Array,
y: Array,
weights: Array | None = None,
) -> Array:
r"""Callback function to compute the potential.
Here we use the formula in Proposition 2 of [2]:
f_{\epsilon}(x) = - \epsilon \log (\sum_{i}
exp ( g_{\epsilon}(y_i) - dist(x, y_i) ) b_i
here b_i is the marginal density of set B (associated with y_i).
Args:
x: Collection of points of set A where f_{\epsilon} will be computed at.
potential: Potential of set B.
y: Collection of point of set B.
weights: Quadrature weights (or marginal densities) for the y points.
Returns:
The potential of set A.
"""
x = jnp.atleast_2d(x)

if weights is None:
num_y = y.shape[0]
weights = jnp.ones((num_y,)) / num_y

if x.shape[-1] != y.shape[-1]:
raise ValueError(
"x and y should have the same feature dimension, but"
f" they have shape {x.shape[-1]}, {y.shape[-1]}, respectively."
)

# Computes cost matrix with respect to the current x.
cost = jnp.squeeze(self._compute_cost(x, y))
z = (potential - cost) / self.epsilon
lse = -self.epsilon * jax.scipy.special.logsumexp(z, b=weights, axis=-1)
return jnp.squeeze(lse)

def transport_fn_direct(
self, potential: Array, y: Array, weights: Array | None
) -> Callable[[Array], Array]:
"""Transport directly (not very stable). Using the formulas in [1]."""
if potential.ndim != 1:
raise ValueError(
"The potential should be a vector, but its dimension are not one,"
f" instead {y.ndim}"
)
if potential.shape[0] != y.shape[0]:
raise ValueError(
"We assume that the potential comes from solving Sinkhorn, but"
f" potential.shape[0] != y.shape[0]: {potential.shape}, {y.shape}"
)

if not weights:
num_y = y.shape[0]
weights = jnp.ones((num_y,)) / num_y

def _transport_direct(x: Array) -> Array:
# The dimension should be (1, num_y)
cost = jnp.squeeze(self._compute_cost(x, y))
z = jnp.exp((potential - cost) / self.epsilon) * weights
return jnp.sum(y * z[:, None], axis=-1)/jnp.sum(z)

return _transport_direct

0 comments on commit 47e98cf

Please sign in to comment.