Skip to content

Commit

Permalink
add DINOv2-FD metric
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell committed Apr 3, 2024
1 parent c529006 commit 09af570
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ venv/
# tests' model weights
tests/weights/
tests/repos/
tests/datasets/

# ruff
.ruff_cache
Expand Down
2 changes: 2 additions & 0 deletions src/refiners/foundationals/dinov2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
DINOv2_large_reg,
DINOv2_small,
DINOv2_small_reg,
preprocess,
)
from .vit import ViT

Expand All @@ -16,4 +17,5 @@
"DINOv2_small",
"DINOv2_small_reg",
"ViT",
"preprocess",
]
20 changes: 18 additions & 2 deletions src/refiners/foundationals/dinov2/dinov2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
import torch
from PIL import Image

from refiners.fluxion.utils import image_to_tensor, normalize
from refiners.foundationals.dinov2.vit import ViT

# TODO: add preprocessing logic like
# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77

def preprocess(img: Image.Image, dim: int = 224) -> torch.Tensor:
"""
Preprocess an image for use with DINOv2. Uses ImageNet mean and standard deviation.
Note that this only resizes and normalizes the image, there is no center crop.
Args:
img: The image.
dim: The square dimension to resize the image. Typically 224 or 518.
Returns:
A float32 tensor with shape (3, dim, dim).
"""
img = img.convert("RGB").resize((dim, dim)) # type: ignore
t = image_to_tensor(img).squeeze()
return normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


class DINOv2_small(ViT):
Expand Down
117 changes: 117 additions & 0 deletions src/refiners/training_utils/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from refiners.foundationals import dinov2


def get_dinov2_representations(
model: dinov2.ViT,
dataloader: DataLoader[torch.Tensor],
dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
"""
Get DINOV2 representations required to compute DINOv2-FD.
Args:
model: The DINOv2 model to use.
dataloader: A dataloader that returns batches of preprocessed images.
dtype: The dtype to use for the representations. Use float64 for good precision.
Returns:
A tensor with shape (batch, embedding_dim).
"""
r: list[torch.Tensor] = []
for batch in dataloader:
assert isinstance(batch, torch.Tensor)
batch_size = batch.shape[0]
assert batch.shape == (batch_size, 3, 224, 224)
batch = batch.to(model.device)

with torch.no_grad():
pred = model(batch)[:, 0] # only keep class embeddings

assert isinstance(pred, torch.Tensor)
assert pred.shape == (batch_size, model.embedding_dim)

r.append(pred.to(dtype))

return torch.cat(r)


def frechet_distance(reps_a: torch.Tensor, reps_b: torch.Tensor) -> float:
"""
Compute the Fréchet distance between two sets of representations.
Args:
reps_a: First set of representations (typically the reference). Shape (batch, N).
reps_a: Second set of representations (typically the test set). Shape (batch, N).
"""
assert reps_a.dim() == 2 and reps_b.dim() == 2, "representations must have shape (batch, N)"
assert reps_a.shape[1] == reps_b.shape[1], "representations must have the same dimension"

mean_a = torch.mean(reps_a, dim=0)
cov_a = torch.cov(reps_a.t())
mean_b = torch.mean(reps_b, dim=0)
cov_b = torch.cov(reps_b.t())

# The trace of the square root of a matrix is the sum of the square roots of its eigenvalues.
trace = (torch.linalg.eigvals(cov_a.mm(cov_b)) ** 0.5).real.sum() # type: ignore
assert isinstance(trace, torch.Tensor)

score = ((mean_a - mean_b) ** 2).sum() + cov_a.trace() + cov_b.trace() - 2 * trace
return score.item()


class DinoDataset(Dataset[torch.Tensor]):
def __init__(self, path: str | Path) -> None:
if isinstance(path, str):
path = Path(path)
self.image_paths = sorted(path.glob("*.png"))

def __len__(self) -> int:
return len(self.image_paths)

def __getitem__(self, i: int) -> torch.Tensor:
path = self.image_paths[i]
img = Image.open(path) # type: ignore
return dinov2.preprocess(img)


def dinov2_frechet_distance(
dataset_a: Dataset[torch.Tensor] | str | Path,
dataset_b: Dataset[torch.Tensor] | str | Path,
model: dinov2.ViT,
batch_size: int = 64,
dtype: torch.dtype = torch.float64,
) -> float:
"""
Compute DINOv2-based Fréchet Distance between two datasets.
There may be small discrepancies with other implementations due to the fact that DINOv2 in Refiners
uses the new style interpolation whereas DINOv2-FD historically uses the legacy implementation
(see https://github.com/facebookresearch/dinov2/pull/378)
Args:
dataset_a: First dataset (typically the reference). Can also be a path to a directory of PNG images.
If a dataset is passed, it must preprocess the data using `dinov2.preprocess`.
dataset_b: Second dataset (typically the test set). See `dataset_a` for details. Size can be different.
model: The DINOv2 model to use.
batch_size: The batch size to use.
dtype: The dtype to use for the representations. Use float64 for good precision.
"""

if not isinstance(dataset_a, Dataset):
dataset_a = DinoDataset(dataset_a)
if not isinstance(dataset_b, Dataset):
dataset_b = DinoDataset(dataset_b)

dataloader_a = DataLoader(dataset_a, batch_size=batch_size, shuffle=False)
dataloader_b = DataLoader(dataset_b, batch_size=batch_size, shuffle=False)

reps_a = get_dinov2_representations(model, dataloader_a, dtype)
reps_b = get_dinov2_representations(model, dataloader_b, dtype)

return frechet_distance(reps_a, reps_b)
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

PARENT_PATH = Path(__file__).parent

collect_ignore = ["weights", "repos", "datasets"]
collect_ignore_glob = ["*_ref"]


@fixture(scope="session")
def test_device() -> torch.device:
Expand All @@ -21,6 +24,12 @@ def test_weights_path() -> Path:
return Path(from_env) if from_env else PARENT_PATH / "weights"


@fixture(scope="session")
def test_datasets_path() -> Path:
from_env = os.getenv("REFINERS_TEST_DATASETS_DIR")
return Path(from_env) if from_env else PARENT_PATH / "datasets"


@fixture(scope="session")
def test_repos_path() -> Path:
from_env = os.getenv("REFINERS_TEST_REPOS_DIR")
Expand Down
67 changes: 67 additions & 0 deletions tests/training_utils/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from pathlib import Path
from warnings import warn

import pytest
import torch
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10 # type: ignore

from refiners.foundationals import dinov2
from refiners.training_utils.metrics import dinov2_frechet_distance


class CifarDataset(Dataset[torch.Tensor]):
def __init__(self, ds: Dataset[list[torch.Tensor]], max_len: int = 512) -> None:
self.ds = ds
ds_length = len(self.ds) # type: ignore
self.length = min(ds_length, max_len)

def __len__(self) -> int:
return self.length

def __getitem__(self, i: int) -> torch.Tensor:
return self.ds[i][0]


@pytest.fixture(scope="module")
def dinov2_l(
test_weights_path: Path,
test_device: torch.device,
) -> dinov2.DINOv2_large:
weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)

model = dinov2.DINOv2_large(device=test_device)
model.load_from_safetensors(weights)
return model


def test_dinov2_frechet_distance(test_datasets_path: Path, dinov2_l: dinov2.DINOv2_large) -> None:
path = str(test_datasets_path / "CIFAR10")

ds_train = CifarDataset(
CIFAR10(
root=path,
train=True,
download=True,
transform=dinov2.preprocess,
)
)

ds_test = CifarDataset(
CIFAR10(
root=path,
train=False,
download=True,
transform=dinov2.preprocess,
)
)

# Computed using dgm-eval (https://github.com/layer6ai-labs/dgm-eval)
# with interpolate_offset=0 and random_sample=False.
expected_d = 837.978

d = dinov2_frechet_distance(ds_train, ds_test, dinov2_l)
assert expected_d - 1e-2 < d < expected_d + 1e-2

0 comments on commit 09af570

Please sign in to comment.