From 528bb174e7ceeb57c0969229bb510b68ebdcf63b Mon Sep 17 00:00:00 2001 From: Felix Pei <64850082+felixp8@users.noreply.github.com> Date: Mon, 5 Feb 2024 14:53:39 +0100 Subject: [PATCH] fix c2st-knn bug and add simple example --- labproject/metrics/c2st.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/labproject/metrics/c2st.py b/labproject/metrics/c2st.py index a58cf01..a8059da 100644 --- a/labproject/metrics/c2st.py +++ b/labproject/metrics/c2st.py @@ -2,6 +2,7 @@ import numpy as np import torch +import inspect from torch import ones, zeros, eye, sum, Tensor, tensor, allclose, manual_seed from torch.distributions import MultivariateNormal, Normal from sklearn.ensemble import RandomForestClassifier @@ -332,7 +333,9 @@ def c2st_scores( X = X.cpu().numpy() Y = Y.cpu().numpy() - clf = clf_class(random_state=seed, **clf_kwargs) + if "random_state" in inspect.signature(clf_class.__init__).parameters.keys(): + clf_kwargs["random_state"] = seed + clf = clf_class(**clf_kwargs) # prepare data data = np.concatenate((X, Y)) @@ -382,3 +385,15 @@ def test_optimal_c2st(): c2st = c2st_optimal(d1, d2, 100_000) target = Normal(0.0, 1.0).cdf(tensor(mean_diff // 2)) assert allclose(c2st, target, atol=1e-3) + + +if __name__ == "__main__": + # Generate random samples + samples1 = torch.randn(100, 2) + samples2 = torch.randn(100, 2) + + # Compute sliced wasserstein distance + c2st_nn_score = c2st_nn(samples1, samples2) + c2st_knn_score = c2st_knn(samples1, samples2) + c2st_rf_score = c2st_rf(samples1, samples2) + print(f"C2ST-NN: {c2st_nn_score}\nC2ST-KNN: {c2st_knn_score}\nC2ST-RF: {c2st_rf_score}")