diff --git a/.gitignore b/.gitignore index a8e9f40c3..99c137ca5 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ venv/ # tests' model weights tests/weights/ tests/repos/ +tests/datasets/ # ruff .ruff_cache diff --git a/src/refiners/foundationals/dinov2/__init__.py b/src/refiners/foundationals/dinov2/__init__.py index 91cfa794f..5ef56d9d2 100644 --- a/src/refiners/foundationals/dinov2/__init__.py +++ b/src/refiners/foundationals/dinov2/__init__.py @@ -5,6 +5,7 @@ DINOv2_large_reg, DINOv2_small, DINOv2_small_reg, + preprocess, ) from .vit import ViT @@ -16,4 +17,5 @@ "DINOv2_small", "DINOv2_small_reg", "ViT", + "preprocess", ] diff --git a/src/refiners/foundationals/dinov2/dinov2.py b/src/refiners/foundationals/dinov2/dinov2.py index 13d598fd2..a9840074f 100644 --- a/src/refiners/foundationals/dinov2/dinov2.py +++ b/src/refiners/foundationals/dinov2/dinov2.py @@ -1,9 +1,25 @@ import torch +from PIL import Image +from refiners.fluxion.utils import image_to_tensor, normalize from refiners.foundationals.dinov2.vit import ViT -# TODO: add preprocessing logic like -# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77 + +def preprocess(img: Image.Image, dim: int = 224) -> torch.Tensor: + """ + Preprocess an image for use with DINOv2. Uses ImageNet mean and standard deviation. + Note that this only resizes and normalizes the image, there is no center crop. + + Args: + img: The image. + dim: The square dimension to resize the image. Typically 224 or 518. + + Returns: + A float32 tensor with shape (3, dim, dim). + """ + img = img.convert("RGB").resize((dim, dim)) # type: ignore + t = image_to_tensor(img).squeeze() + return normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) class DINOv2_small(ViT): diff --git a/src/refiners/training_utils/metrics.py b/src/refiners/training_utils/metrics.py new file mode 100644 index 000000000..3c26df249 --- /dev/null +++ b/src/refiners/training_utils/metrics.py @@ -0,0 +1,117 @@ +from pathlib import Path + +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset + +from refiners.foundationals import dinov2 + + +def get_dinov2_representations( + model: dinov2.ViT, + dataloader: DataLoader[torch.Tensor], + dtype: torch.dtype = torch.float64, +) -> torch.Tensor: + """ + Get DINOV2 representations required to compute DINOv2-FD. + + Args: + model: The DINOv2 model to use. + dataloader: A dataloader that returns batches of preprocessed images. + dtype: The dtype to use for the representations. Use float64 for good precision. + + Returns: + A tensor with shape (batch, embedding_dim). + """ + r: list[torch.Tensor] = [] + for batch in dataloader: + assert isinstance(batch, torch.Tensor) + batch_size = batch.shape[0] + assert batch.shape == (batch_size, 3, 224, 224) + batch = batch.to(model.device) + + with torch.no_grad(): + pred = model(batch)[:, 0] # only keep class embeddings + + assert isinstance(pred, torch.Tensor) + assert pred.shape == (batch_size, model.embedding_dim) + + r.append(pred.to(dtype)) + + return torch.cat(r) + + +def frechet_distance(reps_a: torch.Tensor, reps_b: torch.Tensor) -> float: + """ + Compute the Fréchet distance between two sets of representations. + + Args: + reps_a: First set of representations (typically the reference). Shape (batch, N). + reps_a: Second set of representations (typically the test set). Shape (batch, N). + """ + assert reps_a.dim() == 2 and reps_b.dim() == 2, "representations must have shape (batch, N)" + assert reps_a.shape[1] == reps_b.shape[1], "representations must have the same dimension" + + mean_a = torch.mean(reps_a, dim=0) + cov_a = torch.cov(reps_a.t()) + mean_b = torch.mean(reps_b, dim=0) + cov_b = torch.cov(reps_b.t()) + + # The trace of the square root of a matrix is the sum of the square roots of its eigenvalues. + trace = (torch.linalg.eigvals(cov_a.mm(cov_b)) ** 0.5).real.sum() # type: ignore + assert isinstance(trace, torch.Tensor) + + score = ((mean_a - mean_b) ** 2).sum() + cov_a.trace() + cov_b.trace() - 2 * trace + return score.item() + + +class DinoDataset(Dataset[torch.Tensor]): + def __init__(self, path: str | Path) -> None: + if isinstance(path, str): + path = Path(path) + self.image_paths = sorted(path.glob("*.png")) + + def __len__(self) -> int: + return len(self.image_paths) + + def __getitem__(self, i: int) -> torch.Tensor: + path = self.image_paths[i] + img = Image.open(path) # type: ignore + return dinov2.preprocess(img) + + +def dinov2_frechet_distance( + dataset_a: Dataset[torch.Tensor] | str | Path, + dataset_b: Dataset[torch.Tensor] | str | Path, + model: dinov2.ViT, + batch_size: int = 64, + dtype: torch.dtype = torch.float64, +) -> float: + """ + Compute DINOv2-based Fréchet Distance between two datasets. + + There may be small discrepancies with other implementations due to the fact that DINOv2 in Refiners + uses the new style interpolation whereas DINOv2-FD historically uses the legacy implementation + (see https://github.com/facebookresearch/dinov2/pull/378) + + Args: + dataset_a: First dataset (typically the reference). Can also be a path to a directory of PNG images. + If a dataset is passed, it must preprocess the data using `dinov2.preprocess`. + dataset_b: Second dataset (typically the test set). See `dataset_a` for details. Size can be different. + model: The DINOv2 model to use. + batch_size: The batch size to use. + dtype: The dtype to use for the representations. Use float64 for good precision. + """ + + if not isinstance(dataset_a, Dataset): + dataset_a = DinoDataset(dataset_a) + if not isinstance(dataset_b, Dataset): + dataset_b = DinoDataset(dataset_b) + + dataloader_a = DataLoader(dataset_a, batch_size=batch_size, shuffle=False) + dataloader_b = DataLoader(dataset_b, batch_size=batch_size, shuffle=False) + + reps_a = get_dinov2_representations(model, dataloader_a, dtype) + reps_b = get_dinov2_representations(model, dataloader_b, dtype) + + return frechet_distance(reps_a, reps_b) diff --git a/tests/conftest.py b/tests/conftest.py index bcd3a69e3..ba56acb12 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,9 @@ PARENT_PATH = Path(__file__).parent +collect_ignore = ["weights", "repos", "datasets"] +collect_ignore_glob = ["*_ref"] + @fixture(scope="session") def test_device() -> torch.device: @@ -21,6 +24,12 @@ def test_weights_path() -> Path: return Path(from_env) if from_env else PARENT_PATH / "weights" +@fixture(scope="session") +def test_datasets_path() -> Path: + from_env = os.getenv("REFINERS_TEST_DATASETS_DIR") + return Path(from_env) if from_env else PARENT_PATH / "datasets" + + @fixture(scope="session") def test_repos_path() -> Path: from_env = os.getenv("REFINERS_TEST_REPOS_DIR") diff --git a/tests/training_utils/test_metrics.py b/tests/training_utils/test_metrics.py new file mode 100644 index 000000000..9dcaf0a7b --- /dev/null +++ b/tests/training_utils/test_metrics.py @@ -0,0 +1,67 @@ +from pathlib import Path +from warnings import warn + +import pytest +import torch +from torch.utils.data import Dataset +from torchvision.datasets import CIFAR10 # type: ignore + +from refiners.foundationals import dinov2 +from refiners.training_utils.metrics import dinov2_frechet_distance + + +class CifarDataset(Dataset[torch.Tensor]): + def __init__(self, ds: Dataset[list[torch.Tensor]], max_len: int = 512) -> None: + self.ds = ds + ds_length = len(self.ds) # type: ignore + self.length = min(ds_length, max_len) + + def __len__(self) -> int: + return self.length + + def __getitem__(self, i: int) -> torch.Tensor: + return self.ds[i][0] + + +@pytest.fixture(scope="module") +def dinov2_l( + test_weights_path: Path, + test_device: torch.device, +) -> dinov2.DINOv2_large: + weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors" + if not weights.is_file(): + warn(f"could not find weights at {weights}, skipping") + pytest.skip(allow_module_level=True) + + model = dinov2.DINOv2_large(device=test_device) + model.load_from_safetensors(weights) + return model + + +def test_dinov2_frechet_distance(test_datasets_path: Path, dinov2_l: dinov2.DINOv2_large) -> None: + path = str(test_datasets_path / "CIFAR10") + + ds_train = CifarDataset( + CIFAR10( + root=path, + train=True, + download=True, + transform=dinov2.preprocess, + ) + ) + + ds_test = CifarDataset( + CIFAR10( + root=path, + train=False, + download=True, + transform=dinov2.preprocess, + ) + ) + + # Computed using dgm-eval (https://github.com/layer6ai-labs/dgm-eval) + # with interpolate_offset=0 and random_sample=False. + expected_d = 837.978 + + d = dinov2_frechet_distance(ds_train, ds_test, dinov2_l) + assert expected_d - 1e-2 < d < expected_d + 1e-2