
Commit d631d2b

Merge pull request OpenMined#1967 from kamathhrishi/cifar10_exp
PATE
2 parents 72d8316 + cd84b1f

File tree: 2 files changed (+289, -3)

syft/frameworks/torch/differential_privacy/pate.py (+244, -3)
@@ -24,6 +24,7 @@
 import os
 import math
 import numpy as np
+import torch
 
 # import tensorflow as tf
 #
@@ -50,7 +51,7 @@
5051

5152

5253
def compute_q_noisy_max(counts, noise_eps):
53-
"""returns ~ Pr[outcome != winner].
54+
"""Returns ~ Pr[outcome != winner].
5455
5556
Args:
5657
counts: a list of scores
@@ -65,16 +66,19 @@ def compute_q_noisy_max(counts, noise_eps):
 
     winner = np.argmax(counts)
     counts_normalized = noise_eps * (counts - counts[winner])
+
     counts_rest = np.array([counts_normalized[i] for i in range(len(counts)) if i != winner])
     q = 0.0
     for c in counts_rest:
         gap = -c
+
         q += (gap + 2.0) / (4.0 * math.exp(gap))
+
     return min(q, 1.0 - (1.0 / len(counts)))
 
 
 def compute_q_noisy_max_approx(counts, noise_eps):
-    """returns ~ Pr[outcome != winner].
+    """Returns ~ Pr[outcome != winner].
 
     Args:
         counts: a list of scores
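A quick standalone sanity check of the bound accumulated in that loop, which approximates Pr[noisy argmax != true winner] under report-noisy-max (the toy counts below are chosen for illustration, not taken from this commit): with strong teacher consensus the flip probability should be tiny.

    import math
    import numpy as np

    def q_noisy_max(counts, noise_eps):
        # Same computation as compute_q_noisy_max in the diff above.
        winner = np.argmax(counts)
        counts_normalized = noise_eps * (counts - counts[winner])
        counts_rest = np.array([counts_normalized[i] for i in range(len(counts)) if i != winner])
        q = 0.0
        for c in counts_rest:
            gap = -c
            q += (gap + 2.0) / (4.0 * math.exp(gap))
        return min(q, 1.0 - (1.0 / len(counts)))

    # 100 teachers, 10 labels: 91 votes for label 0, one vote for each other label.
    print(q_noisy_max(np.array([91.0] + [1.0] * 9), noise_eps=0.1))  # ~0.003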
@@ -213,24 +217,28 @@ def perform_analysis(teacher_preds, indices, noise_eps, delta=1e-5, moments=8, beta=0.09):
 
     assert num_examples == _num_examples
 
-    counts_mat = np.zeros((num_examples, num_labels)).astype(np.int32)
+    counts_mat = np.zeros((num_examples, num_labels))
 
     for i in range(num_examples):
         for j in range(num_teachers):
             counts_mat[i, int(teacher_preds[j, i])] += 1
 
     l_list = 1.0 + np.array(range(moments))
+
     total_log_mgf_nm = np.array([0.0 for _ in l_list])
     total_ss_nm = np.array([0.0 for _ in l_list])
 
     for i in indices:
+
         total_log_mgf_nm += np.array(
             [logmgf_from_counts(counts_mat[i], noise_eps, l) for l in l_list]
         )
+
         total_ss_nm += np.array([smoothed_sens(counts_mat[i], noise_eps, l, beta) for l in l_list])
 
     # We want delta = exp(alpha - eps * l).
     # Solving gives eps = (alpha - ln(delta)) / l
+
     eps_list_nm = (total_log_mgf_nm - math.log(delta)) / l_list
 
     # print("Epsilons (Noisy Max): " + str(eps_list_nm))
@@ -266,3 +274,236 @@ def perform_analysis(teacher_preds, indices, noise_eps, delta=1e-5, moments=8, beta=0.09):
     # print("Data independent bound = " + str(min(data_ind_eps_list)) + ".")
 
     return min(eps_list_nm), min(data_ind_eps_list)
+
+
+def tensors_to_literals(tensor_list):
+    """Converts a list of torch tensors to a list of integers/floats.
+
+    Workaround for the missing functionality that converts a list of
+    tensors into a single tensor.
+
+    Args:
+        tensor_list: list of torch tensors
+
+    Returns:
+        literal_list: list of floats/integers
+    """
+    literal_list = []
+    for tensor in tensor_list:
+        literal_list.append(tensor.item())
+    return literal_list
+
+
+def logmgf_exact_torch(q, priv_eps, l):
+    """Computes the logmgf value given q and privacy eps.
+
+    The bound used is the min of three terms. The first term is from
+    https://arxiv.org/pdf/1605.02065.pdf. The second term is based on
+    the fact that when an event has probability (1 - q) for q close to
+    zero, q can only change by exp(eps), which corresponds to a much
+    smaller multiplicative change in (1 - q). The third term comes
+    directly from the privacy guarantee.
+
+    Args:
+        q: probability of the non-optimal outcome
+        priv_eps: eps parameter for DP
+        l: moment to compute
+
+    Returns:
+        Upper bound on logmgf
+    """
+    if q < 0.5:
+        t_one = (1 - q) * math.pow((1 - q) / (1 - math.exp(priv_eps) * q), l)
+        t_two = q * math.exp(priv_eps * l)
+        t = t_one + t_two
+        try:
+            log_t = math.log(t)
+        except ValueError:
+            print("Got ValueError in math.log for values: " + str((q, priv_eps, l, t)))
+            log_t = priv_eps * l
+    else:
+        log_t = priv_eps * l
+
+    return min(0.5 * priv_eps * priv_eps * l * (l + 1), log_t, priv_eps * l)
+
+
+def compute_q_noisy_max_torch(counts, noise_eps):
+    """Returns ~ Pr[outcome != winner].
+
+    Args:
+        counts: a list of scores
+        noise_eps: privacy parameter for noisy_max
+
+    Returns:
+        q: the probability that the outcome is different from the true winner
+    """
+    if not isinstance(counts, torch.Tensor):
+        counts = torch.tensor(tensors_to_literals(counts), dtype=torch.float)
+
+    _, winner = counts.max(0)
+    counts_normalized = noise_eps * (counts - counts[winner])
+    counts_normalized = tensors_to_literals(counts_normalized)
+    counts_rest = torch.tensor(
+        [counts_normalized[i] for i in range(len(counts)) if i != winner], dtype=torch.float
+    )
+    q = 0.0
+    for c in counts_rest:
+        gap = -c
+        q += (gap + 2.0) / (4.0 * math.exp(gap))
+
+    return min(q, 1.0 - (1.0 / len(counts)))
+
+
+def logmgf_from_counts_torch(counts, noise_eps, l):
+    """ReportNoisyMax mechanism with noise_eps is 2*noise_eps-DP in our
+    setting, where one count can go up by 1 and another can go down by 1.
+    """
+    q = compute_q_noisy_max_torch(counts, noise_eps)
+    return logmgf_exact_torch(q, 2.0 * noise_eps, l)
+
+
+def sens_at_k_torch(counts, noise_eps, l, k):
+    """Returns sensitivity at distance k.
+
+    Args:
+        counts: an array of scores
+        noise_eps: noise parameter used
+        l: moment whose sensitivity is being computed
+        k: distance
+
+    Returns:
+        sensitivity at distance k
+    """
+    counts_sorted = sorted(counts, reverse=True)
+
+    if 0.5 * noise_eps * l > 1:
+        print("l too large to compute sensitivity")
+        return 0
+
+    if counts[0] < counts[1] + k:
+        return 0
+
+    counts_sorted[0] -= k
+    counts_sorted[1] += k
+    val = logmgf_from_counts_torch(counts_sorted, noise_eps, l)
+    counts_sorted[0] -= 1
+    counts_sorted[1] += 1
+    val_changed = logmgf_from_counts_torch(counts_sorted, noise_eps, l)
+    return val_changed - val
+
+
+def smooth_sens_torch(counts, noise_eps, l, beta):
+    """Computes the beta-smooth sensitivity.
+
+    Args:
+        counts: array of scores
+        noise_eps: noise parameter
+        l: moment of interest
+        beta: smoothness parameter
+
+    Returns:
+        smooth_sensitivity: a beta-smooth upper bound
+    """
+    k = 0
+    smoothed_sensitivity = sens_at_k_torch(counts, noise_eps, l, k)
+
+    while k < max(counts):
+        k += 1
+        sensitivity_at_k = sens_at_k_torch(counts, noise_eps, l, k)
+        smoothed_sensitivity = max(smoothed_sensitivity, math.exp(-beta * k) * sensitivity_at_k)
+        if sensitivity_at_k == 0.0:
+            break
+
+    return smoothed_sensitivity
+
+
+def perform_analysis_torch(preds, indices, noise_eps=0.1, delta=1e-5, moments=8, beta=0.09):
+    """Performs PATE analysis on predictions from teachers and combined predictions for student.
+
+    Args:
+        preds: a torch tensor of dim (num_teachers x num_examples). Each value corresponds
+            to the index of the label which a teacher gave for a specific example.
+        indices: a torch tensor of dim (num_examples) of aggregated examples which were
+            aggregated using the noisy max mechanism.
+        noise_eps: the epsilon level used to create the indices
+        delta: the desired level of delta
+        moments: the number of moments to track (see the paper)
+        beta: a smoothing parameter (see the paper)
+
+    Returns:
+        tuple: first value is the data-dependent epsilon, then the data-independent epsilon
+    """
+    num_teachers, num_examples = preds.shape
+    _num_examples = indices.shape[0]
+
+    assert num_examples == _num_examples
+
+    labels = list(preds.flatten())
+    labels = set([tensor.item() for tensor in labels])
+    num_labels = len(labels)
+
+    counts_mat = torch.zeros(num_examples, num_labels, dtype=torch.float32)
+
+    for i in range(num_examples):
+        for j in range(num_teachers):
+            counts_mat[i, int(preds[j, i])] += 1
+
+    l_list = 1 + torch.tensor(range(moments), dtype=torch.float)
+
+    total_log_mgf_nm = torch.tensor([0.0 for _ in l_list], dtype=torch.float)
+    total_ss_nm = torch.tensor([0.0 for _ in l_list], dtype=torch.float)
+
+    for i in indices:
+        total_log_mgf_nm += torch.tensor(
+            [logmgf_from_counts_torch(counts_mat[i].clone(), noise_eps, l) for l in l_list]
+        )
+        total_ss_nm += torch.tensor(
+            [smooth_sens_torch(counts_mat[i].clone(), noise_eps, l, beta) for l in l_list],
+            dtype=torch.float,
+        )
+
+    eps_list_nm = (total_log_mgf_nm - math.log(delta)) / l_list
+    ss_eps = 2.0 * beta * math.log(1 / delta)
+    ss_scale = 2.0 / ss_eps
+    if min(eps_list_nm) == eps_list_nm[-1]:
+        print(
+            "Warning: May not have used enough values of l. Increase 'moments' variable and run again."
+        )
+
+    data_ind_log_mgf = torch.tensor([0.0 for _ in l_list])
+    data_ind_log_mgf += num_examples * torch.tensor(
+        tensors_to_literals([logmgf_exact_torch(1.0, 2.0 * noise_eps, l) for l in l_list])
+    )
+
+    data_ind_eps_list = (data_ind_log_mgf - math.log(delta)) / l_list
+
+    return min(eps_list_nm), min(data_ind_eps_list)
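With those additions, pate gains a torch twin of perform_analysis. A minimal end-to-end sketch of the intended call, mirroring the new test below (the fake-prediction setup comes from that test; the local variable names are illustrative):

    import numpy as np

    from syft.frameworks.torch.differential_privacy import pate

    num_teachers, num_examples, num_labels = (100, 50, 10)
    # Fake teacher predictions: one label index per (teacher, example) pair.
    preds = (np.random.rand(num_teachers, num_examples) * num_labels).astype(int)
    indices = (np.random.rand(num_examples) * num_labels).astype(int)  # aggregated answers

    preds[:, 0:10] *= 0  # give the first ten examples perfect consensus, as in the tests

    data_dep_eps, data_ind_eps = pate.perform_analysis_torch(
        preds, indices, noise_eps=0.1, delta=1e-5
    )
    print(data_dep_eps, data_ind_eps)  # the data-dependent epsilon should be the tighter bound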

test/torch/differential_privacy/test_pate.py (+45)
@@ -1,11 +1,17 @@
 import numpy as np
+
+import torch
+
 from syft.frameworks.torch.differential_privacy import pate
 
+np.random.seed(0)
+
 
 def test_base_dataset():
 
     num_teachers, num_examples, num_labels = (100, 50, 10)
     preds = (np.random.rand(num_teachers, num_examples) * num_labels).astype(int)  # fake preds
+
     indices = (np.random.rand(num_examples) * num_labels).astype(int)  # true answers
 
     preds[:, 0:10] *= 0
@@ -15,3 +21,42 @@ def test_base_dataset():
     )
 
     assert data_dep_eps < data_ind_eps
+
+
+def test_base_dataset_torch():
+
+    num_teachers, num_examples, num_labels = (100, 50, 10)
+    preds = (np.random.rand(num_teachers, num_examples) * num_labels).astype(int)  # fake preds
+
+    indices = (np.random.rand(num_examples) * num_labels).astype(int)  # true answers
+
+    preds[:, 0:10] *= 0
+
+    data_dep_eps, data_ind_eps = pate.perform_analysis_torch(
+        preds, indices, noise_eps=0.1, delta=1e-5
+    )
+
+    assert data_dep_eps < data_ind_eps
+
+
+def test_torch_ref_match():
+
+    # Verify that the torch implementation's values match the original numpy implementation.
+
+    num_teachers, num_examples, num_labels = (100, 50, 10)
+    preds = (np.random.rand(num_teachers, num_examples) * num_labels).astype(int)  # fake preds
+
+    indices = (np.random.rand(num_examples) * num_labels).astype(int)  # true answers
+
+    preds[:, 0:10] *= 0
+
+    data_dep_eps, data_ind_eps = pate.perform_analysis_torch(
+        preds, indices, noise_eps=0.1, delta=1e-5
+    )
+
+    data_dep_eps_ref, data_ind_eps_ref = pate.perform_analysis(
+        preds, indices, noise_eps=0.1, delta=1e-5
+    )
+
+    assert torch.isclose(data_dep_eps, torch.tensor(data_dep_eps_ref.item()))
+    assert torch.isclose(data_ind_eps, torch.tensor(data_ind_eps_ref.item()))
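Both new tests follow the repository's pytest layout; assuming pytest is installed, this module alone can be run from Python with:

    import pytest

    # Exit code 0 means every PATE test in this file passed.
    raise SystemExit(pytest.main(["-q", "test/torch/differential_privacy/test_pate.py"]))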
