Privacy-preserving Keras models optionally encrypt, noise, and mask.
james-choncholas committed Oct 10, 2024
1 parent 2144ee8 commit 451c2d9
Showing 6 changed files with 277 additions and 195 deletions.
178 changes: 115 additions & 63 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -26,11 +26,14 @@ def __init__(
self,
layers,
shell_context_fn,
use_encryption,
labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
needs_public_rotation_key=False,
noise_multiplier=1.0,
disable_encryption=False,
disable_masking=False,
disable_noise=False,
check_overflow_close=True,
*args,
**kwargs,
):
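For orientation, a hedged sketch of how the new flags might be passed at construction time. The class name `DpSgdSequential`, the import path, and the placeholder layer list and context function below are assumptions, not taken from this commit:

```python
# Hypothetical usage of the new constructor flags. The class name, import path,
# layer list, and context function are placeholders/assumptions.
import tf_shell_ml

model = tf_shell_ml.DpSgdSequential(
    layers=my_shell_layers,          # user-built list of tf_shell_ml layers
    shell_context_fn=my_context_fn,  # user-supplied callable returning a shell context
    noise_multiplier=1.0,
    disable_encryption=False,        # keep label encryption on
    disable_masking=False,           # keep gradient masking on
    disable_noise=False,             # keep DP noise on
    check_overflow_close=True,
)
```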
@@ -48,11 +51,14 @@ def __init__(
self.layers[-1].activation_deriv = None

self.shell_context_fn = shell_context_fn
self.use_encryption = use_encryption
self.labels_party_dev = labels_party_dev
self.features_party_dev = features_party_dev
self.clipping_threshold = 10000000
self.needs_public_rotation_key = needs_public_rotation_key
self.noise_multiplier = noise_multiplier
self.disable_encryption = disable_encryption
self.disable_masking = disable_masking
self.disable_noise = disable_noise
self.check_overflow_close = check_overflow_close

def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
@@ -78,14 +84,40 @@ def build(self, input_shape):
if hasattr(l, "unpacking_funcs"):
self.unpacking_funcs.extend(l.unpacking_funcs())

def compute_max_two_norm_and_pred(self, features):
with tf.GradientTape(persistent=tf.executing_eagerly()) as tape:
y_pred = self(features, training=True) # forward pass
grads = tape.jacobian(
y_pred,
self.trainable_variables,
parallel_iterations=1,
experimental_use_pfor=False,
)
# ^ layers list x (batch size x num output classes x weights) matrix

if len(grads) == 0:
raise ValueError("No gradients found")
slot_size = tf.shape(grads[0])[0]
num_output_classes = grads[0].shape[1]

flat_grads = [tf.reshape(g, [slot_size, num_output_classes, -1]) for g in grads]
# ^ layers list (batch_size x num output classes x flattened layer weights)
all_grads = tf.concat(flat_grads, axis=2)
# ^ batch_size x num output classes x all flattened weights
two_norms = tf.map_fn(lambda x: tf.norm(x, axis=0), all_grads)
max_two_norm = tf.reduce_max(two_norms)
return max_two_norm, y_pred
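For reference, a standalone sketch (not from this commit) of the same idea on a plain Keras model: take the jacobian of the predictions with respect to every weight, flatten per example, and reduce to a single norm bound. Reducing over all jacobian entries per example below is a simplification of the per-class reduction used above.

```python
# Standalone sketch: per-example jacobian of y_pred w.r.t. the weights,
# reduced to one scalar bound. Model and shapes are made up.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
x = tf.random.normal([8, 4])  # batch of 8 examples, 4 features

with tf.GradientTape() as tape:
    y_pred = model(x, training=True)

# One jacobian per variable, shaped [batch, num_outputs, *var.shape].
jacs = tape.jacobian(y_pred, model.trainable_variables)
flat = [tf.reshape(j, [tf.shape(j)[0], tf.shape(j)[1], -1]) for j in jacs]
all_grads = tf.concat(flat, axis=2)              # [batch, outputs, all weights]
per_example_norms = tf.norm(all_grads, axis=[1, 2])
max_two_norm = tf.reduce_max(per_example_norms)
```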

def train_step(self, data):
x, y = data

with tf.device(self.labels_party_dev):
if self.use_encryption:
if self.disable_encryption:
enc_y = y
else:
key_path = tempfile.mkdtemp() # Every trace gets a new key.
shell_context = self.shell_context_fn()

shell_context = self.shell_context_fn()
secret_key = tf_shell.create_key64(
shell_context, key_path + "/secret_key"
)
@@ -98,12 +130,11 @@ def train_step(self, data):
)
# Encrypt the batch of secret labels y.
enc_y = tf_shell.to_encrypted(y, secret_key, shell_context)
else:
enc_y = y

with tf.device(self.features_party_dev):
# Forward pass in plaintext.
y_pred = self(x, training=True)
# y_pred = self(x, training=True)
max_two_norm, y_pred = self.compute_max_two_norm_and_pred(x)

# Backward pass.
dx = self.loss_fn.grad(enc_y, y_pred)
@@ -115,38 +146,51 @@ def train_step(self, data):
else:
dw, dx = l.backward(
dJ_dx[-1],
public_rotation_key if self.needs_public_rotation_key else None,
(
public_rotation_key
if not self.disable_encryption
and self.needs_public_rotation_key
else None
),
)
dJ_dw.extend(dw)
dJ_dx.append(dx)

if len(dJ_dw) == 0:
raise ValueError("No gradients found.")

# Setup parameters for the masking.
t = tf.cast(shell_context.plaintext_modulus, tf.float32)
t_half = t // 2
mask_scaling_factors = [g._scaling_factor for g in reversed(dJ_dw)]

# Mask the encrypted grads to prepare for decryption.
mask = [
tf.random.uniform(
tf_shell.shape(g),
dtype=tf.float32,
minval=-t_half / s,
maxval=t_half / s,
)
for g, s in zip(reversed(dJ_dw), mask_scaling_factors)
# tf.zeros_like(tf_shell.shape(g), dtype=tf.int64)
# for g in dJ_dw
]
# Mask the encrypted grads to prepare for decryption. The masks may
# overflow during the reduce_sum over the batch. When the masks are
# operated on, they are multiplied by the scaling factor, so it is
# not necessary to mask the full range -t/2 to t/2. (Though it is
# possible, it unnecessarily introduces noise into the ciphertext.)
if self.disable_masking or self.disable_encryption:
masked_enc_grads = [g for g in reversed(dJ_dw)]
else:
t = tf.cast(shell_context.plaintext_modulus, tf.float32)
t_half = t // 2
mask_scaling_factors = [g._scaling_factor for g in reversed(dJ_dw)]
mask = [
tf.random.uniform(
tf_shell.shape(g),
dtype=tf.float32,
minval=-t_half / s,
maxval=t_half / s,
)
for g, s in zip(reversed(dJ_dw), mask_scaling_factors)
# tf.zeros_like(tf_shell.shape(g), dtype=tf.int64)
# for g in dJ_dw
]

# Mask the encrypted gradients and reverse the order to match the
# order of the layers.
masked_enc_grads = [(g + m) for g, m in zip(reversed(dJ_dw), mask)]
# Mask the encrypted gradients and reverse the order to match
# the order of the layers.
masked_enc_grads = [(g + m) for g, m in zip(reversed(dJ_dw), mask)]
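The masking step, illustrated on plain floating-point tensors rather than tf-shell ciphertexts (a sketch with made-up values for the modulus t and scaling factor s): adding a uniform mask drawn from the centered plaintext range hides the gradient from the decrypting party, and subtracting the same mask later recovers it, up to the wrap-around handled by the rebalance step below.

```python
# Sketch with plain tensors; the real code masks tf-shell ciphertexts.
# t and s are made-up example values.
import tensorflow as tf

t = 65537.0        # example plaintext modulus
t_half = t // 2
s = 8.0            # example fixed-point scaling factor

grad = tf.constant([1.25, -3.5, 0.75])
mask = tf.random.uniform(tf.shape(grad), minval=-t_half / s, maxval=t_half / s)

masked = grad + mask        # what the labels party sees after decryption
recovered = masked - mask   # the features party removes the mask it generated
# recovered ≈ grad, up to fixed-point/rounding error
```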

with tf.device(self.labels_party_dev):
if self.use_encryption:
if self.disable_encryption:
# Unpacking is not necessary when not using encryption.
masked_grads = masked_enc_grads
else:
# Decrypt the weight gradients.
packed_masked_grads = [
tf_shell.to_tensorflow(
@@ -164,47 +208,56 @@ def train_step(self, data):
masked_grads = [
f(g) for f, g in zip(self.unpacking_funcs, packed_masked_grads)
]
else:
masked_grads = masked_enc_grads

with tf.device(self.features_party_dev):
# SHELL represents floats as integers between [0, t) where t is the
# plaintext modulus. To mimic the modulo operation without SHELL,
# numbers which exceed the range [-t/2, t/2) are shifted back into
# the range.
def rebalance(x_list, t_half, scaling_factor_list):
x_list = [
tf.where(x > t_half / s + (1 / s - 1e-6), x - t / s, x)
for x, s in zip(x_list, scaling_factor_list)
]
x_list = [
tf.where(x < -t_half / s - (1 / s - 1e-6), x + t / s, x)
for x, s in zip(x_list, scaling_factor_list)
if self.disable_masking or self.disable_encryption:
grads = masked_grads
else:
# SHELL represents floats as integers between [0, t) where t is the
# plaintext modulus. To mimic the modulo operation without SHELL,
# numbers which exceed the range [-t/2, t/2) are shifted back into
# the range.
epsilon = tf.constant(1e-6, dtype=float)

def rebalance(x, s):
r_bound = t_half / s + epsilon
l_bound = -t_half / s - epsilon
t_over_s = t / s
x = tf.where(x > r_bound, x - t_over_s, x)
x = tf.where(x < l_bound, x + t_over_s, x)
return x

# Unmask the gradients using the mask. The unpacking function may
# sum the mask from two of the gradients (one from each batch), so
# the mask must be brought back into the range of [-t/2, t/2] before
# subtracting it from the gradient, and again after.
unpacked_mask = [f(m) for f, m in zip(self.unpacking_funcs, mask)]
unpacked_mask = [
rebalance(m, s) for m, s in zip(unpacked_mask, mask_scaling_factors)
]
return x_list

# Unmask the gradients using the mask. The unpacking function may
# sum the mask from two of the gradients (one from each batch), so
# the mask must be brought back into the range of [-t/2, t/2] before
# subtracting it from the gradient, and again after.
unpacked_mask = [f(m) for f, m in zip(self.unpacking_funcs, mask)]
unpacked_mask = rebalance(unpacked_mask, t_half, mask_scaling_factors)
grads = [mg - m for mg, m in zip(masked_grads, unpacked_mask)]
grads = rebalance(grads, t_half, mask_scaling_factors)
grads = [mg - m for mg, m in zip(masked_grads, unpacked_mask)]
grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]
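A worked standalone example of the centered shift that rebalance performs, with made-up values for t and s: anything that left the range [-t/(2s), t/(2s)] after modular arithmetic is pulled back by one multiple of t/s.

```python
# Standalone illustration of the centered shift done by rebalance above.
# t, s, and the input values are made up.
import tensorflow as tf

t = tf.constant(64.0)
t_half = t // 2                      # 32
s = tf.constant(4.0)                 # so the valid range is [-8, 8]
epsilon = tf.constant(1e-6)

def rebalance(x, s):
    r_bound = t_half / s + epsilon
    l_bound = -t_half / s - epsilon
    x = tf.where(x > r_bound, x - t / s, x)
    x = tf.where(x < l_bound, x + t / s, x)
    return x

x = tf.constant([1.0, 9.5, -10.0])   # 9.5 and -10.0 fall outside [-8, 8]
print(rebalance(x, s).numpy())       # -> [ 1.  -6.5  6. ]
```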

# TODO: set stddev based on clipping threshold.
noise = [
# tf.random.normal(tf.shape(g), stddev=1, dtype=float) for g in grads
tf.zeros(tf.shape(g))
for g in grads
]
noised_grads = [g + n for g, n in zip(grads, noise)]
if self.disable_noise:
noised_grads = grads
else:
noise = [
tf.random.normal(
tf.shape(g),
stddev=max_two_norm * self.noise_multiplier,
dtype=float,
)
# tf.zeros(tf.shape(g))
for g in grads
]
noised_grads = [g + n for g, n in zip(grads, noise)]
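Standalone sketch of the noise step in the DP-SGD style used here, with made-up shapes: each gradient receives Gaussian noise whose standard deviation is the sensitivity bound times the noise multiplier.

```python
# Standalone sketch of the Gaussian noise step; shapes and values are made up.
import tensorflow as tf

noise_multiplier = tf.constant(1.0)
max_two_norm = tf.constant(3.7)      # stand-in for the per-example norm bound
grads = [tf.random.normal([4, 2]), tf.random.normal([2])]

noised_grads = [
    g + tf.random.normal(tf.shape(g), stddev=max_two_norm * noise_multiplier)
    for g in grads
]
```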

# Apply the gradients to the model.
self.optimizer.apply_gradients(zip(noised_grads, self.weights))

# Do not update metrics during secure training.
if not self.use_encryption:
if self.disable_encryption:
# Update metrics (includes the metric that tracks the loss)
for metric in self.metrics:
if metric.name == "loss":
@@ -215,9 +268,8 @@ def rebalance(x_list, t_half, scaling_factor_list):

metric_results = {m.name: m.result() for m in self.metrics}
else:
metric_results = {}
metric_results = {"num_slots": shell_context.num_slots}

metric_results["num_slots"] = shell_context.num_slots
return metric_results

def test_step(self, data):
19 changes: 15 additions & 4 deletions tf_shell_ml/model_base.py
@@ -37,13 +37,24 @@ def train_step_tf_func(self, data):
# to the same value as the encryption ring degree. Run the training loop once
# on dummy data to figure out the batch size.
def prep_dataset_for_model(self, train_dataset):
if not self.use_encryption:
if self.disable_encryption:
self.dataset_prepped = True
return
return train_dataset

# Run the training loop once on dummy data to figure out the batch size.
tf.config.run_functions_eagerly(False)
metrics = self.train_step_tf_func(next(iter(train_dataset)))

if not isinstance(metrics, dict):
raise ValueError(
f"Expected train_step to return a dict, got {type(metrics)}."
)

if "num_slots" not in metrics:
raise ValueError(
f"Expected train_step to return a dict with key 'num_slots', got {metrics.keys()}."
)

train_dataset = train_dataset.rebatch(
metrics["num_slots"].numpy(), drop_remainder=True
)
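For reference, a standalone sketch of the re-batching step (assumes a TensorFlow version where tf.data.Dataset.rebatch is available; the slot count is made up).

```python
# Standalone sketch of re-batching a dataset to a ciphertext slot count.
# 4096 is a made-up value; the real code reads it from the traced train step.
import tensorflow as tf

num_slots = 4096
ds = tf.data.Dataset.from_tensor_slices(tf.range(10000)).batch(32)
ds = ds.rebatch(num_slots, drop_remainder=True)

for batch in ds.take(1):
    print(batch.shape)  # (4096,)
```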
@@ -56,8 +67,8 @@ def prep_dataset_for_model(self, train_dataset):
# `prep_dataset_for_model` because it does not execute the graph, instead
# tracing and optimizing the graph and extracting the required parameters.
def fast_prep_dataset_for_model(self, train_dataset):
if not self.use_encryption:
return
if not self.disable_encryption:
return train_dataset

# Call the training step with keygen to trace the graph. Use a copy
# of the function to avoid caching the trace.
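As an aside, a standalone sketch of tracing a step without executing it, by wrapping it in a fresh tf.function and requesting a concrete function (the step function below is a stand-in):

```python
# Standalone sketch: trace without executing. The step function is a stand-in.
import tensorflow as tf

def train_step(x):
    return x * 2

traced = tf.function(train_step).get_concrete_function(
    tf.TensorSpec([None, 4], tf.float32)
)
# traced.graph can now be inspected for shapes and parameters without running
# the step.
```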
