From dc3a649cc387b1b7b3a57822c6d22f3927a25625 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Mon, 5 Feb 2024 10:31:48 +0100 Subject: [PATCH 1/3] Add optimal c2st --- labproject/metrics/c2st.py | 42 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 labproject/metrics/c2st.py diff --git a/labproject/metrics/c2st.py b/labproject/metrics/c2st.py new file mode 100644 index 0000000..7d5fd38 --- /dev/null +++ b/labproject/metrics/c2st.py @@ -0,0 +1,42 @@ +from typing import Any + +from torch import ones, zeros, eye, sum, Tensor, tensor, allclose, manual_seed +from torch.distributions import MultivariateNormal, Normal + + +def c2st_optimal(density1: Any, density2: Any, n_monte_carlo: int = 10_000) -> Tensor: + """Return the c2st that can be achieved by an optimal classifier. + + This requires that both densities have `.log_prob()` functions. + + Args: + density1: The first density. Must have `.sample()` and `.log_prob()`. + density2: The second density. Must have `.sample()` and `.log_prob()`. + + Returns: + The closed-form c2st (between 0.5 and 1.0). + """ + assert getattr(density1, "log_prob", None), "density1 has no `.log_prob()`" + assert getattr(density2, "log_prob", None), "density1 has no `.log_prob()`" + + d1_samples = density1.sample((n_monte_carlo // 2,)) + d2_samples = density2.sample((n_monte_carlo // 2,)) + + density_ratios1 = density1.log_prob(d1_samples) >= density2.log_prob(d1_samples) + density_ratios2 = density1.log_prob(d2_samples) < density2.log_prob(d2_samples) + + return (sum(density_ratios1) + sum(density_ratios2)) / n_monte_carlo + + +def test_optimal_c2st(): + """Tests the c2st on 1D Gaussians against the cdf of that Gaussian.""" + _ = manual_seed(0) + dim = 1 + mean_diff = 4.0 + + d1 = MultivariateNormal(0.0 * ones(dim), eye(dim)) + d2 = MultivariateNormal(mean_diff * ones(dim), eye(dim)) + + c2st = c2st_optimal(d1, d2, 100_000) + target = Normal(0.0, 1.0).cdf(tensor(mean_diff // 2)) + assert allclose(c2st, target, atol=1e-3) From e4b26fb018e67d5f248a5cf3c5c350c91b424495 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Mon, 5 Feb 2024 10:37:47 +0100 Subject: [PATCH 2/3] Add docstring for number of monte carlo samples --- labproject/metrics/c2st.py | 1 + 1 file changed, 1 insertion(+) diff --git a/labproject/metrics/c2st.py b/labproject/metrics/c2st.py index 7d5fd38..80855ff 100644 --- a/labproject/metrics/c2st.py +++ b/labproject/metrics/c2st.py @@ -12,6 +12,7 @@ def c2st_optimal(density1: Any, density2: Any, n_monte_carlo: int = 10_000) -> T Args: density1: The first density. Must have `.sample()` and `.log_prob()`. density2: The second density. Must have `.sample()` and `.log_prob()`. + n_monte_carlo: Number of Monte-Carlo samples that the computation is based on. Returns: The closed-form c2st (between 0.5 and 1.0). From 3976d200dc5df4cbf8aaad6cc69b2bdffafa9e46 Mon Sep 17 00:00:00 2001 From: Felix Pei <64850082+felixp8@users.noreply.github.com> Date: Mon, 5 Feb 2024 11:38:03 +0100 Subject: [PATCH 3/3] fix typo in imports --- labproject/metrics/c2st.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/labproject/metrics/c2st.py b/labproject/metrics/c2st.py index d05112d..703ea74 100644 --- a/labproject/metrics/c2st.py +++ b/labproject/metrics/c2st.py @@ -12,10 +12,6 @@ # from sbi: https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py -from typing import Any - -from torch import Tensor - def c2st_nn( X: Tensor,