
Commit

metrics register and getter functions
gmoss13 committed Feb 6, 2024
1 parent 07a879f commit 68f08ca
Showing 6 changed files with 88 additions and 14 deletions.
31 changes: 20 additions & 11 deletions docs/notebooks/compare_swd_sinkhorn.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 161,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -21,30 +21,39 @@
},
{
"cell_type": "code",
"execution_count": 162,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'gaussian_squared_wasserstein_distance': <function gaussian_squared_w2_distance at 0x7fe32954a5e0>, 'sliced_wasserstein_distance': <function sliced_wasserstein_distance at 0x7fe32954a700>, 'wasserstein_sinkhorn': <function sinkhorn_loss at 0x7fe32954ad30>, 'wasserstein_gauss_squared': <function gaussian_squared_w2_distance at 0x7fe31a0b6c10>, 'sliced_wasserstein': <function sliced_wasserstein_distance at 0x7fe32954ae50>}\n"
]
}
],
"source": [
"import torch\n",
"from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance\n",
"from labproject.metrics.wasserstein_sinkhorn import sinkhorn_loss\n",
"from labproject.metrics.gaussian_squared_wasserstein import gaussian_squared_w2_distance\n",
"from labproject.metrics import get_metric,METRICS\n",
"print(METRICS)\n",
"\n",
"\n"
"sinkhorn_loss = get_metric('wasserstein_sinkhorn')\n",
"sliced_wasserstein_distance = get_metric('sliced_wasserstein')\n",
"gaussian_squared_w2_distance = get_metric('wasserstein_gauss_squared')\n"
]
},
{
"cell_type": "code",
"execution_count": 163,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.07661963999271393\n",
"0.1118786375773585\n",
"0.06996694207191467\n"
"0.09820912033319473\n",
"0.12430790905907681\n",
"0.10992895066738129\n"
]
}
],
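Read end to end, the updated notebook cell amounts to the sketch below: the metrics are looked up by their registered names instead of being imported from their modules directly. The sample shapes and the `randn` inputs are illustrative stand-ins, not the data used in the notebook.

```python
import torch

from labproject.metrics import get_metric

# Resolve the registered metrics by name instead of importing each module directly.
sinkhorn_loss = get_metric("wasserstein_sinkhorn")
sliced_wasserstein_distance = get_metric("sliced_wasserstein")
gaussian_squared_w2_distance = get_metric("wasserstein_gauss_squared")

# Two (N, D) sample sets; shapes and offset are illustrative only.
real_samples = torch.randn(100, 10)
fake_samples = torch.randn(100, 10) + 0.1

print(gaussian_squared_w2_distance(real_samples, fake_samples))
print(sinkhorn_loss(real_samples, fake_samples, epsilon=1e-3, niter=1000, p=2))
print(sliced_wasserstein_distance(real_samples, fake_samples))
```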
16 changes: 13 additions & 3 deletions labproject/metrics/__init__.py
@@ -5,9 +5,19 @@
2. The function should be well-documented, including type hints.
3. The function should be tested with a simple example.
4. Add an assert at the beginning for shape checking (N,D), see examples.
5. Register the function with the `register_metric` decorator from `labproject.metrics.utils` and give it a meaningful name.
"""

from labproject.metrics.gaussian_kl import gaussian_kl_divergence
from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance
import importlib
import pkgutil

METRICS = {}
# Get the list of all submodules in the metrics package
package_name = "labproject.metrics"
package = importlib.import_module(package_name)
module_names = [name for _, name, _ in pkgutil.iter_modules(package.__path__)]

# Import all the metrics modules that have a register_metric decorator
for module_name in module_names:
module = importlib.import_module(f"{package_name}.{module_name}")
if hasattr(module, "register_metric"):
globals().update(module.__dict__)
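To make the numbered guidelines at the top of `labproject/metrics/__init__.py` concrete, a new metric module would look roughly like the sketch below. The module name `mean_difference.py` and the metric itself are invented for illustration; only the `(N, D)` assert and the `register_metric` decorator reflect the actual convention.

```python
# labproject/metrics/mean_difference.py  (hypothetical example module)
import torch
from torch import Tensor

from labproject.metrics.utils import register_metric


@register_metric("mean_difference")
def mean_difference(real_samples: Tensor, fake_samples: Tensor) -> Tensor:
    """Toy metric: L2 distance between the means of two sample sets.

    Args:
        real_samples: Tensor of shape (N, D).
        fake_samples: Tensor of shape (M, D).

    Returns:
        Scalar tensor containing the distance between the two sample means.
    """
    assert real_samples.ndim == 2 and fake_samples.ndim == 2, "Expected (N, D) inputs"
    return torch.norm(real_samples.mean(dim=0) - fake_samples.mean(dim=0))
```

Because `__init__.py` now imports every submodule of `labproject.metrics`, dropping such a file into the package is enough for the decorator to run at import time, so the metric appears in `METRICS` and can be retrieved with `get_metric("mean_difference")`.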
2 changes: 2 additions & 0 deletions labproject/metrics/gaussian_squared_wasserstein.py
@@ -1,8 +1,10 @@
import torch
from torch import Tensor
import scipy
from labproject.metrics.utils import register_metric


@register_metric("wasserstein_gauss_squared")
def gaussian_squared_w2_distance(real_samples: Tensor, fake_samples: Tensor) -> Tensor:
r"""
Compute the squared Wasserstein distance between Gaussian approximations of real and fake samples.
2 changes: 2 additions & 0 deletions labproject/metrics/sliced_wasserstein.py
@@ -3,8 +3,10 @@

import torch
from torch import Tensor
from labproject.metrics.utils import register_metric


@register_metric("sliced_wasserstein")
def sliced_wasserstein_distance(
encoded_samples: Tensor,
distribution_samples: Tensor,
49 changes: 49 additions & 0 deletions labproject/metrics/utils.py
@@ -0,0 +1,49 @@
from typing import Optional
import torch
import functools
import warnings

METRICS = {}


def register_metric(name: str) -> callable:
r"""This decorator wrapps a function that should return a dataset and ensures that the dataset is a PyTorch tensor, with the correct shape.
Args:
name (str): name of supported metric
Returns:
callable: metric function wrapper
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Call the original function
metric = func(*args, **kwargs)
return metric

METRICS[name] = wrapper
return wrapper

return decorator


def get_metric(name: str) -> callable:
r"""Get a metric by name
Args:
name (str): Name of the metric
Returns:
callable: metric function
Example:
from labproject.metrics.utils import get_metric
wasserstein_sinkhorn = get_metric("wasserstein_sinkhorn")
dist = wasserstein_sinkhorn(real_samples, fake_samples, epsilon=1e-3, niter=1000, p=2)
"""
assert name in METRICS, f"Metric {name} not found, please register it first"
return METRICS[name]
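A minimal round trip through the two new helpers, assuming a throwaway metric registered on the fly (the name `dummy_l1` is made up here purely to exercise the registry):

```python
import torch

from labproject.metrics.utils import METRICS, get_metric, register_metric


@register_metric("dummy_l1")
def dummy_l1(real_samples: torch.Tensor, fake_samples: torch.Tensor) -> torch.Tensor:
    # Toy metric: mean absolute difference between two equally shaped sample sets.
    return (real_samples - fake_samples).abs().mean()


assert "dummy_l1" in METRICS            # registration happens when the decorator runs
metric = get_metric("dummy_l1")         # look the wrapped function up by name
print(metric(torch.zeros(8, 2), torch.ones(8, 2)))  # tensor(1.)
```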
2 changes: 2 additions & 0 deletions labproject/metrics/wasserstein_sinkhorn.py
@@ -1,4 +1,5 @@
import torch
from labproject.metrics.utils import register_metric


def sinkhorn_algorithm(
@@ -62,6 +63,7 @@ def MC(u, v):
return cost, transport


@register_metric("wasserstein_sinkhorn")
def sinkhorn_loss(
x: torch.Tensor, y: torch.Tensor, epsilon: float = 1e-3, niter: int = 1000, p: int = 2
):
