From 68f08cab7600294a36d2cac1ca153d8a581d19d4 Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Tue, 6 Feb 2024 11:40:47 +0100 Subject: [PATCH] metrics register and getter functions --- docs/notebooks/compare_swd_sinkhorn.ipynb | 31 +++++++----- labproject/metrics/__init__.py | 16 ++++-- .../metrics/gaussian_squared_wasserstein.py | 2 + labproject/metrics/sliced_wasserstein.py | 2 + labproject/metrics/utils.py | 49 +++++++++++++++++++ labproject/metrics/wasserstein_sinkhorn.py | 2 + 6 files changed, 88 insertions(+), 14 deletions(-) create mode 100644 labproject/metrics/utils.py diff --git a/docs/notebooks/compare_swd_sinkhorn.ipynb b/docs/notebooks/compare_swd_sinkhorn.ipynb index 0dec899..3aec1d0 100644 --- a/docs/notebooks/compare_swd_sinkhorn.ipynb +++ b/docs/notebooks/compare_swd_sinkhorn.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 161, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -21,30 +21,39 @@ }, { "cell_type": "code", - "execution_count": 162, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'gaussian_squared_wasserstein_distance': , 'sliced_wasserstein_distance': , 'wasserstein_sinkhorn': , 'wasserstein_gauss_squared': , 'sliced_wasserstein': }\n" + ] + } + ], "source": [ "import torch\n", - "from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance\n", - "from labproject.metrics.wasserstein_sinkhorn import sinkhorn_loss\n", - "from labproject.metrics.gaussian_squared_wasserstein import gaussian_squared_w2_distance\n", + "from labproject.metrics import get_metric,METRICS\n", + "print(METRICS)\n", "\n", - "\n" + "sinkhorn_loss = get_metric('wasserstein_sinkhorn')\n", + "sliced_wasserstein_distance = get_metric('sliced_wasserstein')\n", + "gaussian_squared_w2_distance = get_metric('wasserstein_gauss_squared')\n" ] }, { "cell_type": "code", - "execution_count": 163, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0.07661963999271393\n", - "0.1118786375773585\n", - "0.06996694207191467\n" + "0.09820912033319473\n", + "0.12430790905907681\n", + "0.10992895066738129\n" ] } ], diff --git a/labproject/metrics/__init__.py b/labproject/metrics/__init__.py index 8bdec7a..03d44f7 100644 --- a/labproject/metrics/__init__.py +++ b/labproject/metrics/__init__.py @@ -5,9 +5,19 @@ 2. The function should be well-documented, including type hints. 3. The function should be tested with a simple example. 4. Add an assert at the beginning for shape checking (N,D), see examples. +5. Register the function by importing `labrpoject.metrics.utils.regiter_metric` and give your function a meaningful name. """ -from labproject.metrics.gaussian_kl import gaussian_kl_divergence -from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance +import importlib +import pkgutil -METRICS = {} +# Get the list of all submodules in the metrics package +package_name = "labproject.metrics" +package = importlib.import_module(package_name) +module_names = [name for _, name, _ in pkgutil.iter_modules(package.__path__)] + +# Import all the metrics modules that have a register_metric decorator +for module_name in module_names: + module = importlib.import_module(f"{package_name}.{module_name}") + if hasattr(module, "register_metric"): + globals().update(module.__dict__) diff --git a/labproject/metrics/gaussian_squared_wasserstein.py b/labproject/metrics/gaussian_squared_wasserstein.py index c4b905a..aa41796 100644 --- a/labproject/metrics/gaussian_squared_wasserstein.py +++ b/labproject/metrics/gaussian_squared_wasserstein.py @@ -1,8 +1,10 @@ import torch from torch import Tensor import scipy +from labproject.metrics.utils import register_metric +@register_metric("wasserstein_gauss_squared") def gaussian_squared_w2_distance(real_samples: Tensor, fake_samples: Tensor) -> Tensor: r""" Compute the squared Wasserstein distance between Gaussian approximations of real and fake samples. diff --git a/labproject/metrics/sliced_wasserstein.py b/labproject/metrics/sliced_wasserstein.py index cd2418c..d935247 100644 --- a/labproject/metrics/sliced_wasserstein.py +++ b/labproject/metrics/sliced_wasserstein.py @@ -3,8 +3,10 @@ import torch from torch import Tensor +from labproject.metrics.utils import register_metric +@register_metric("sliced_wasserstein") def sliced_wasserstein_distance( encoded_samples: Tensor, distribution_samples: Tensor, diff --git a/labproject/metrics/utils.py b/labproject/metrics/utils.py new file mode 100644 index 0000000..ca3f861 --- /dev/null +++ b/labproject/metrics/utils.py @@ -0,0 +1,49 @@ +from typing import Optional +import torch +import functools +import warnings + +METRICS = {} + + +def register_metric(name: str) -> callable: + r"""This decorator wrapps a function that should return a dataset and ensures that the dataset is a PyTorch tensor, with the correct shape. + + Args: + name (str): name of supported metric + + Returns: + callable: metric function wrapper + + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Call the original function + metric = func(*args, **kwargs) + return metric + + METRICS[name] = wrapper + return wrapper + + return decorator + + +def get_metric(name: str) -> callable: + r"""Get a metric by name + + Args: + name (str): Name of the metric + + Returns: + callable: metric function + + Example: + + from labproject.metrics.utils import get_metric + wasserstein_sinkhorn = get_metric("wasserstein_sinkhorn") + dist = wasserstein_sinkhorn(real_samples, fake_samples, epsilon=1e-3, niter=1000, p=2) + """ + assert name in METRICS, f"Distribution {name} not found, please register it first " + return METRICS[name] diff --git a/labproject/metrics/wasserstein_sinkhorn.py b/labproject/metrics/wasserstein_sinkhorn.py index aa2fb11..9ebad36 100644 --- a/labproject/metrics/wasserstein_sinkhorn.py +++ b/labproject/metrics/wasserstein_sinkhorn.py @@ -1,4 +1,5 @@ import torch +from labproject.metrics.utils import register_metric def sinkhorn_algorithm( @@ -62,6 +63,7 @@ def MC(u, v): return cost, transport +@register_metric("wasserstein_sinkhorn") def sinkhorn_loss( x: torch.Tensor, y: torch.Tensor, epsilon: float = 1e-3, niter: int = 1000, p: int = 2 ):