-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
metrics register and getter functions
- Loading branch information
Showing
6 changed files
with
88 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from typing import Optional | ||
import torch | ||
import functools | ||
import warnings | ||
|
||
METRICS = {} | ||
|
||
|
||
def register_metric(name: str) -> callable: | ||
r"""This decorator wrapps a function that should return a dataset and ensures that the dataset is a PyTorch tensor, with the correct shape. | ||
Args: | ||
name (str): name of supported metric | ||
Returns: | ||
callable: metric function wrapper | ||
""" | ||
|
||
def decorator(func): | ||
@functools.wraps(func) | ||
def wrapper(*args, **kwargs): | ||
# Call the original function | ||
metric = func(*args, **kwargs) | ||
return metric | ||
|
||
METRICS[name] = wrapper | ||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
def get_metric(name: str) -> callable: | ||
r"""Get a metric by name | ||
Args: | ||
name (str): Name of the metric | ||
Returns: | ||
callable: metric function | ||
Example: | ||
from labproject.metrics.utils import get_metric | ||
wasserstein_sinkhorn = get_metric("wasserstein_sinkhorn") | ||
dist = wasserstein_sinkhorn(real_samples, fake_samples, epsilon=1e-3, niter=1000, p=2) | ||
""" | ||
assert name in METRICS, f"Distribution {name} not found, please register it first " | ||
return METRICS[name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters