Skip to content

Commit

Permalink
Add torch based MMD implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jsvetter committed Feb 6, 2024
1 parent 5443057 commit 3fa84ec
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions labproject/metrics/MMD_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch

from labproject.metrics.utils import register_metric


# NOTE: all tensors should be of shape (n_samples, n_features)


def rbf_kernel(x, y, bandwidth):
dist = torch.cdist(x, y)
return torch.exp(-(dist**2) / (2.0 * bandwidth**2))


def polynomial_kernel(x, y, degree, bias):
return (x @ y.t() + bias) ** degree


def linear_kernel(x, y):
return x @ y.t()


@register_metric("mmd_rbf")
def compute_rbf_mmd(x, y, bandwidth=1.0):
x_kernel = rbf_kernel(x, x, bandwidth)
y_kernel = rbf_kernel(y, y, bandwidth)
xy_kernel = rbf_kernel(x, y, bandwidth)
mmd = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
return mmd


@register_metric("mmd_polynomial")
def compute_polynomial_mmd(x, y, degree=2, bias=0):
x_kernel = polynomial_kernel(x, x, degree, bias)
y_kernel = polynomial_kernel(y, y, degree, bias)
xy_kernel = polynomial_kernel(x, y, degree, bias)
mmd = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
return mmd


@register_metric("mmd_linear_naive")
def compute_linear_mmd_naive(x, y):
x_kernel = linear_kernel(x, x)
y_kernel = linear_kernel(y, y)
xy_kernel = linear_kernel(x, y)
mmd = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
return mmd


@register_metric("mmd_linear")
def compute_linear_mmd(x, y):
delta = torch.mean(x, 0) - torch.mean(y, 0)
return torch.norm(delta, 2) ** 2

0 comments on commit 3fa84ec

Please sign in to comment.