Skip to content

Commit

Permalink
fix c2st-knn bug and add simple example
Browse files Browse the repository at this point in the history
  • Loading branch information
felixp8 authored Feb 5, 2024
1 parent 7454044 commit 528bb17
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion labproject/metrics/c2st.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import torch
import inspect
from torch import ones, zeros, eye, sum, Tensor, tensor, allclose, manual_seed
from torch.distributions import MultivariateNormal, Normal
from sklearn.ensemble import RandomForestClassifier
Expand Down Expand Up @@ -332,7 +333,9 @@ def c2st_scores(
X = X.cpu().numpy()
Y = Y.cpu().numpy()

clf = clf_class(random_state=seed, **clf_kwargs)
if "random_state" in inspect.signature(clf_class.__init__).parameters.keys():
clf_kwargs["random_state"] = seed
clf = clf_class(**clf_kwargs)

# prepare data
data = np.concatenate((X, Y))
Expand Down Expand Up @@ -382,3 +385,15 @@ def test_optimal_c2st():
c2st = c2st_optimal(d1, d2, 100_000)
target = Normal(0.0, 1.0).cdf(tensor(mean_diff // 2))
assert allclose(c2st, target, atol=1e-3)


if __name__ == "__main__":
# Generate random samples
samples1 = torch.randn(100, 2)
samples2 = torch.randn(100, 2)

# Compute sliced wasserstein distance
c2st_nn_score = c2st_nn(samples1, samples2)
c2st_knn_score = c2st_knn(samples1, samples2)
c2st_rf_score = c2st_rf(samples1, samples2)
print(f"C2ST-NN: {c2st_nn_score}\nC2ST-KNN: {c2st_knn_score}\nC2ST-RF: {c2st_rf_score}")

0 comments on commit 528bb17

Please sign in to comment.