Add optimal c2st #18

Merged · 4 commits · Feb 5, 2024
47 changes: 44 additions & 3 deletions labproject/metrics/c2st.py
@@ -1,14 +1,16 @@
# from sbi: https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py

from typing import Any, Dict, Optional, Literal

import numpy as np
import torch
from torch import ones, zeros, eye, sum, Tensor, tensor, allclose, manual_seed
from torch.distributions import MultivariateNormal, Normal
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier


def c2st_nn(
@@ -341,3 +343,42 @@ def c2st_scores(
    scores = cross_val_score(clf, data, target, cv=shuffle, scoring=metric, verbose=verbosity)

    return scores


def c2st_optimal(density1: Any, density2: Any, n_monte_carlo: int = 10_000) -> Tensor:
    r"""Return the c2st that can be achieved by an optimal classifier.

    This requires that both densities have `.sample()` and `.log_prob()` methods.

    Args:
        density1: The first density. Must have `.sample()` and `.log_prob()`.
        density2: The second density. Must have `.sample()` and `.log_prob()`.
        n_monte_carlo: Number of Monte-Carlo samples that the estimate is based on.

    Returns:
        torch.Tensor containing the Monte-Carlo estimate of the optimal c2st
        (between 0.5 and 1.0).
    """
    assert getattr(density1, "log_prob", None), "density1 has no `.log_prob()`"
    assert getattr(density2, "log_prob", None), "density2 has no `.log_prob()`"

    d1_samples = density1.sample((n_monte_carlo // 2,))
    d2_samples = density2.sample((n_monte_carlo // 2,))

    # The optimal classifier assigns a sample to whichever density gives it the
    # higher log-probability; its accuracy is the fraction of samples that fall
    # on the "correct" side of this likelihood-ratio decision rule.
    density_ratios1 = density1.log_prob(d1_samples) >= density2.log_prob(d1_samples)
    density_ratios2 = density1.log_prob(d2_samples) < density2.log_prob(d2_samples)

    return (sum(density_ratios1) + sum(density_ratios2)) / n_monte_carlo
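

# Illustrative usage sketch (added for exposition; `demo_c2st_optimal_gaussians` is a
# hypothetical helper, not part of this PR). `c2st_optimal` accepts any objects that
# expose `.sample()` and `.log_prob()`, such as torch distributions. For two
# unit-covariance Gaussians whose means differ by a vector of norm d, the optimal
# accuracy is Phi(d / 2); here d = sqrt(8), so the value should be roughly 0.92.
def demo_c2st_optimal_gaussians() -> Tensor:
    d1 = MultivariateNormal(zeros(2), eye(2))
    d2 = MultivariateNormal(2.0 * ones(2), eye(2))
    return c2st_optimal(d1, d2, n_monte_carlo=50_000)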


def test_optimal_c2st():
    """Tests the optimal c2st of two 1D Gaussians against the analytic value Phi(mean_diff / 2)."""
    _ = manual_seed(0)
    dim = 1
    mean_diff = 4.0

    d1 = MultivariateNormal(0.0 * ones(dim), eye(dim))
    d2 = MultivariateNormal(mean_diff * ones(dim), eye(dim))

    c2st = c2st_optimal(d1, d2, 100_000)
    target = Normal(0.0, 1.0).cdf(tensor(mean_diff / 2))
    assert allclose(c2st, target, atol=1e-3)
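
For reference, a quick sanity check on the target value used in `test_optimal_c2st`, assuming the standard result that the optimal classification accuracy for two unit-variance 1D Gaussians whose means differ by `mean_diff` is Phi(mean_diff / 2): with `mean_diff = 4.0` this is Phi(2.0) ≈ 0.9772, and with 100,000 Monte-Carlo samples the standard error of the estimate is about sqrt(0.9772 * 0.0228 / 100_000) ≈ 5e-4, so `atol=1e-3` is roughly a two-sigma tolerance. The snippet below only evaluates that target value and is illustrative, not part of the diff:

from torch import tensor
from torch.distributions import Normal

# Phi(2.0), the optimal accuracy for separating N(0, 1) from N(4, 1)
print(Normal(0.0, 1.0).cdf(tensor(2.0)))  # tensor(0.9772)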