Skip to content

Commit 2e700c1

Browse files
committed
Added test to check if torch implementation matches reference numpy implementation
1 parent 81c0ccb commit 2e700c1

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

test/torch/differential_privacy/test_pate.py

+23
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,26 @@ def test_base_dataset_torch():
3737
)
3838

3939
assert data_dep_eps < data_ind_eps
40+
41+
42+
def test_torch_ref_match():
43+
44+
# Verify if the torch implementation values match the original Numpy implementation.
45+
46+
num_teachers, num_examples, num_labels = (100, 50, 10)
47+
preds = (np.random.rand(num_teachers, num_examples) * num_labels).astype(int) # fake preds
48+
49+
indices = (np.random.rand(num_examples) * num_labels).astype(int) # true answers
50+
51+
preds[:, 0:10] *= 0
52+
53+
data_dep_eps, data_ind_eps = pate.perform_analysis_torch(
54+
preds, indices, noise_eps=0.1, delta=1e-5
55+
)
56+
57+
data_dep_eps_ref, data_ind_eps_ref = pate.perform_analysis(
58+
preds, indices, noise_eps=0.1, delta=1e-5
59+
)
60+
61+
assert torch.isclose(data_dep_eps, torch.tensor(data_dep_eps_ref.item()))
62+
assert torch.isclose(data_ind_eps, torch.tensor(data_ind_eps_ref.item()))

0 commit comments

Comments
 (0)