Secure noise protocol in sequential models.
james-choncholas committed Oct 14, 2024
1 parent b4b3d6d commit f3a9456
Showing 7 changed files with 313 additions and 163 deletions.
7 changes: 4 additions & 3 deletions tf_shell/cc/optimizers/moduli_autotune.cc
@@ -182,9 +182,10 @@ StatusOr<bool> DecryptUsesSameContext(utils::MutableNodeView const* node_view,
}
trace = trace->GetRegularFanin(0).node_view();

// The next op should be a tensor list gather. This is how the context was
// created.
if (trace->GetOp() != "TensorListGather") {
// The next op could be a tensor list gather (if the context was created on
// a local device) or a ParseTensor (if the context is read from a cache).
if (trace->GetOp() != "TensorListGather" && trace->GetOp() != "ParseTensor") {
std::cout << "Trace: " << trace->node()->DebugString() << std::endl;
return errors::InvalidArgument(
"Traceback to context expected the second op to be a tensor list "
"gather, but found ",
148 changes: 80 additions & 68 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -24,7 +24,8 @@ class DpSgdSequential(SequentialBase):
def __init__(
self,
layers,
shell_context_fn,
backprop_context_fn,
noise_context_fn,
labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
needs_public_rotation_key=False,
@@ -50,7 +51,8 @@ def __init__(
# be a no-op.
self.layers[-1].activation_deriv = None

self.shell_context_fn = shell_context_fn
self.backprop_context_fn = backprop_context_fn
self.noise_context_fn = noise_context_fn
self.labels_party_dev = labels_party_dev
self.features_party_dev = features_party_dev
self.needs_public_rotation_key = needs_public_rotation_key
@@ -96,17 +98,8 @@ def compute_max_two_norm_and_pred(self, features):
)
# ^ layers list x (batch size x num output classes x weights) matrix

if len(grads) == 0:
raise ValueError("No gradients found")
slot_size = tf.shape(grads[0])[0]
num_output_classes = grads[0].shape[1]

flat_grads = [tf.reshape(g, [slot_size, num_output_classes, -1]) for g in grads]
# ^ layers list (batch_size x num output classes x flattened layer weights)
all_grads = tf.concat(flat_grads, axis=2)
# ^ batch_size x num output classes x all flattened weights
two_norms = tf.map_fn(lambda x: tf.norm(x, axis=0), all_grads)
max_two_norm = tf.reduce_max(two_norms)
all_grads, _, _, _ = self.flatten_jacobian_list(grads)
max_two_norm = self.flat_jacobian_two_norm(all_grads)
return max_two_norm, y_pred

def train_step(self, data):
@@ -116,21 +109,24 @@ def train_step(self, data):
if self.disable_encryption:
enc_y = y
else:
shell_context = self.shell_context_fn()
secret_key = tf_shell.create_key64(shell_context, self.cache_path)
secret_fast_rotation_key = tf_shell.create_fast_rotation_key64(
shell_context, secret_key, self.cache_path
backprop_context = self.backprop_context_fn()
backprop_secret_key = tf_shell.create_key64(
backprop_context, self.cache_path
)
backprop_secret_fastrot_key = tf_shell.create_fast_rotation_key64(
backprop_context, backprop_secret_key, self.cache_path
)
if self.needs_public_rotation_key:
public_rotation_key = tf_shell.create_rotation_key64(
shell_context, secret_key, self.cache_path
public_backprop_rotation_key = tf_shell.create_rotation_key64(
backprop_context, backprop_secret_key, self.cache_path
)
else:
public_backprop_rotation_key = None
# Encrypt the batch of secret labels y.
enc_y = tf_shell.to_encrypted(y, secret_key, shell_context)
enc_y = tf_shell.to_encrypted(y, backprop_secret_key, backprop_context)

with tf.device(self.features_party_dev):
# Forward pass in plaintext.
# y_pred = self(x, training=True)
max_two_norm, y_pred = self.compute_max_two_norm_and_pred(x)

# Backward pass.
@@ -141,30 +137,19 @@ def train_step(self, data):
if isinstance(l, tf_shell_ml.GlobalAveragePooling1D):
dw, dx = l.backward(dJ_dx[-1])
else:
dw, dx = l.backward(
dJ_dx[-1],
(
public_rotation_key
if not self.disable_encryption
and self.needs_public_rotation_key
else None
),
)
dw, dx = l.backward(dJ_dx[-1], public_backprop_rotation_key)
dJ_dw.extend(dw)
dJ_dx.append(dx)

if len(dJ_dw) == 0:
raise ValueError("No gradients found.")

# Mask the encrypted grads to prepare for decryption. The masks may
# overflow during the reduce_sum over the batch. When the masks are
# operated on, they are multiplied by the scaling factor, so it is
# not necessary to mask the full range -t/2 to t/2. (Masking the full
# range is possible, but unnecessarily introduces noise into the
# ciphertext.)
if self.disable_masking or self.disable_encryption:
masked_enc_grads = [g for g in reversed(dJ_dw)]
grads = [g for g in reversed(dJ_dw)]
else:
t = tf.cast(shell_context.plaintext_modulus, tf.float32)
t = tf.cast(backprop_context.plaintext_modulus, tf.float32)
t_half = t // 2
mask_scaling_factors = [g._scaling_factor for g in reversed(dJ_dw)]
mask = [
@@ -181,35 +166,77 @@ def train_step(self, data):

# Mask the encrypted gradients and reverse the order to match
# the order of the layers.
masked_enc_grads = [(g + m) for g, m in zip(reversed(dJ_dw), mask)]
grads = [(g + m) for g, m in zip(reversed(dJ_dw), mask)]
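
As an aside (not part of the commit), additive masking modulo the plaintext modulus is what hides the gradients from the labels party, which later decrypts them. A minimal pure-Python illustration of the mask/unmask round trip, with made-up numbers:

# Illustration only: additive masking mod t hides a value from whoever
# decrypts it, and the party holding the mask can remove it afterwards.
t = 2**16                        # stand-in plaintext modulus
secret = 12345                   # e.g. a scaled gradient entry
mask = 54321                     # sampled by the features party
masked = (secret + mask) % t     # what the labels party sees after decryption
recovered = (masked - mask) % t  # features party removes the mask
assert recovered == secret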

if not self.disable_noise:
# Features party encrypts the max two norm to send to the labels
# party so they can scale the noise.
noise_context = self.noise_context_fn()
noise_secret_key = tf_shell.create_key64(noise_context, self.cache_path)
max_two_norm = tf.expand_dims(max_two_norm, 0)
max_two_norm = tf.repeat(max_two_norm, noise_context.num_slots, axis=0)
enc_max_two_norm = tf_shell.to_encrypted(
max_two_norm, noise_secret_key, noise_context
)

with tf.device(self.labels_party_dev):
if self.disable_encryption:
# Unpacking is not necessary when not using encryption.
masked_grads = masked_enc_grads
else:
# Decrypt the weight gradients.
packed_masked_grads = [
if not self.disable_encryption:
# Decrypt the weight gradients with the backprop key.
packed_grads = [
tf_shell.to_tensorflow(
g,
secret_fast_rotation_key if g._is_fast_rotated else secret_key,
(
backprop_secret_fastrot_key
if g._is_fast_rotated
else backprop_secret_key
),
)
for g in masked_enc_grads
for g in grads
]

# Unpack the plaintext gradients using the corresponding layer's
# unpack function.
# TODO: Make sure this doesn't require sending the layers
# themselves just for unpacking. The weights should not be
# shared with the labels party.
masked_grads = [
f(g) for f, g in zip(self.unpacking_funcs, packed_masked_grads)
]
grads = [f(g) for f, g in zip(self.unpacking_funcs, packed_grads)]

if not self.disable_noise:
# Efficiently pack the masked gradients to prepare for
# encryption. This is special because the masked gradients are
# no longer batched, so the packing must be done manually.
(
flat_grads,
grad_shapes,
flattened_grad_shapes,
total_grad_size,
) = self.flatten_and_pad_grad_list(grads, noise_context.num_slots)

# Sample the noise.
noise = tf.random.normal(
tf.shape(flat_grads),
stddev=self.noise_multiplier,
dtype=float,
)
# Scale it by the encrypted max two norm.
enc_noise = enc_max_two_norm * noise
# Add the encrypted noise to the flat masked gradients.
grads = enc_noise + flat_grads

with tf.device(self.features_party_dev):
if self.disable_masking or self.disable_encryption:
grads = masked_grads
else:
if not self.disable_noise:
# The gradients must first be decrypted using the noise
# secret key.
flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)
# Unpack the noise after decryption.
grads = self.unflatten_and_unpad_grad(
flat_grads,
grad_shapes,
flattened_grad_shapes,
total_grad_size,
)
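
Condensed, the secure-noise round trip introduced by this commit looks roughly like the sketch below. It is not part of the commit and only uses tf_shell calls that appear in this diff; names such as noise_context_fn, cache_path, noise_multiplier, max_two_norm, and flat_grads are stand-ins for the corresponding values in train_step:

# Features party: build a dedicated noise context/key and encrypt the
# sensitivity (max two norm) so the labels party never sees it in the clear.
noise_context = noise_context_fn()
noise_secret_key = tf_shell.create_key64(noise_context, cache_path)
sensitivity = tf.repeat(
    tf.expand_dims(max_two_norm, 0), noise_context.num_slots, axis=0
)
enc_sensitivity = tf_shell.to_encrypted(sensitivity, noise_secret_key, noise_context)

# Labels party: sample Gaussian noise, scale it under encryption, and add it
# to the flattened (masked) gradients. The result stays encrypted.
noise = tf.random.normal(tf.shape(flat_grads), stddev=noise_multiplier)
enc_noised_grads = enc_sensitivity * noise + flat_grads

# Features party: decrypt with the noise key it generated, then unpad and unmask.
noised_grads = tf_shell.to_tensorflow(enc_noised_grads, noise_secret_key)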

if not self.disable_masking and not self.disable_encryption:
# SHELL represents floats as integers between [0, t) where t is the
# plaintext modulus. To mimic the modulo operation without SHELL,
# numbers which exceed the range [-t/2, t/2) are shifted back into
@@ -232,7 +259,7 @@ def rebalance(x, s):
unpacked_mask = [
rebalance(m, s) for m, s in zip(unpacked_mask, mask_scaling_factors)
]
grads = [mg - m for mg, m in zip(masked_grads, unpacked_mask)]
grads = [mg - m for mg, m in zip(grads, unpacked_mask)]
grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]
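
The body of rebalance is collapsed in this view. A minimal sketch of what such a helper could look like follows; it is an assumption based on the surrounding comment (decrypted values at scaling factor s must be recentred into [-t/2, t/2)), not the committed implementation:

def rebalance(x, s):
    # Sketch only. backprop_context is in scope inside train_step. Values
    # representing negatives come out near t/s and are shifted back down;
    # anything pushed below -t/(2*s) by the mask subtraction is shifted up.
    t = tf.cast(backprop_context.plaintext_modulus, x.dtype)
    shift = t / s
    x = tf.where(x > shift / 2, x - shift, x)
    x = tf.where(x < -shift / 2, x + shift, x)
    return x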

if not self.disable_encryption and self.check_overflow:
@@ -251,23 +278,8 @@ def rebalance(x, s):
lambda: tf.identity(overflowed),
)

# TODO: set stddev based on clipping threshold.
if self.disable_noise:
noised_grads = grads
else:
noise = [
tf.random.normal(
tf.shape(g),
stddev=max_two_norm * self.noise_multiplier,
dtype=float,
)
# tf.zeros(tf.shape(g))
for g in grads
]
noised_grads = [g + n for g, n in zip(grads, noise)]

# Apply the gradients to the model.
self.optimizer.apply_gradients(zip(noised_grads, self.weights))
self.optimizer.apply_gradients(zip(grads, self.weights))

# Do not update metrics during secure training.
if self.disable_encryption:
Expand All @@ -281,7 +293,7 @@ def rebalance(x, s):

metric_results = {m.name: m.result() for m in self.metrics}
else:
metric_results = {"num_slots": shell_context.num_slots}
metric_results = {"num_slots": backprop_context.num_slots}

return metric_results

125 changes: 125 additions & 0 deletions tf_shell_ml/model_base.py
@@ -117,3 +117,128 @@ def fit(self, train_dataset, **kwargs):
train_dataset = self.prep_dataset_for_model(train_dataset)

return super().fit(train_dataset, **kwargs)

def flatten_jacobian_list(self, grads):
"""Takes as input a jacobian and flattens into a single tensor. The
jacobian is expected to be of the form:
layers list x (batch size x num output classes x weights)
where weights may be any shape. The output is a tensor of shape:
batch_size x num output classes x all flattened weights
where all flattened weights include weights from all the layers.
"""
if len(grads) == 0:
raise ValueError("No gradients found")

# Get the shapes from TensorFlow's tensors rather than SHELL's context,
# for when the batch size != the slotting dim or encryption is disabled.
slot_size = tf.shape(grads[0])[0]
num_output_classes = grads[0].shape[1]
grad_shapes = [g.shape[2:] for g in grads]
flattened_grad_shapes = [s.num_elements() for s in grad_shapes]

flat_grads = [
tf.reshape(g, [slot_size, num_output_classes, s])
for g, s in zip(grads, flattened_grad_shapes)
]
# ^ layers list (batch_size x num output classes x flattened layer weights)

all_grads = tf.concat(flat_grads, axis=2)
# ^ batch_size x num output classes x all flattened weights

return all_grads, slot_size, grad_shapes, flattened_grad_shapes
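
A small shape-only illustration (not part of the commit) of what flatten_jacobian_list returns; model stands in for a hypothetical DpSgdSequential/SequentialBase instance:

import tensorflow as tf

# Two parameter tensors (a 5x3 kernel and a length-7 vector), batch of 8,
# 10 output classes.
jacobians = [
    tf.random.normal([8, 10, 5, 3]),  # layers list x (batch x classes x weights)
    tf.random.normal([8, 10, 7]),
]
flat, slot_size, shapes, flat_shapes = model.flatten_jacobian_list(jacobians)
# flat has shape [8, 10, 22]  (22 = 5*3 + 7 flattened weights)
# shapes == [TensorShape([5, 3]), TensorShape([7])]
# flat_shapes == [15, 7]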

def flat_jacobian_two_norm(self, flat_jacobian):
"""Computes the maximum L2 norm of a flattened jacobian. The input is
expected to be of shape:
batch_size x num output classes x all flattened weights
where all flattened weights includes weights from all the layers."""
# The DP sensitivity of backprop, when revealing the sum of gradients
# across a batch, is the maximum L2 norm of the gradient over every
# example in the batch and over every class.
two_norms = tf.map_fn(lambda x: tf.norm(x, axis=0), flat_jacobian)
# ^ batch_size x num output classes
max_two_norm = tf.reduce_max(two_norms)
# ^ scalar
return max_two_norm
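
For context (not from the commit), this is the usual Gaussian-mechanism reasoning: if every per-example, per-class gradient g_i satisfies ||g_i||_2 <= max_two_norm, then the L2 sensitivity of sum_i g_i under add/remove adjacency is max_two_norm, so per-coordinate noise drawn from N(0, (noise_multiplier * max_two_norm)^2) gives the standard (epsilon, delta) guarantee for the chosen noise multiplier. This is why train_step scales the sampled noise by this maximum norm.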

def unflatten_batch_grad_list(
self, flat_grads, slot_size, grad_shapes, flattened_grad_shapes
):
"""Takes as input a flattened gradient tensor and unflattens it into a
list of tensors. This is useful to undo the flattening performed by
flatten_jacobian_list() after the output class dimension has been reduced.
The input is expected to be of shape:
batch_size x all flattened weights
where all flattened weights includes weights from all the layers. The
output is a list of tensors of shape:
layers list x (batch size x weights)
"""
# Split to recover the gradients by layer.
grad_list = tf_shell.split(flat_grads, flattened_grad_shapes, axis=1)
# ^ layers list (batch_size x flattened weights)

# Unflatten the gradients to the original layer shape.
grad_list = [
tf_shell.reshape(
g,
tf.concat([[slot_size], tf.cast(s, dtype=tf.int64)], axis=0),
)
for g, s in zip(grad_list, grad_shapes)
]
# ^ layers list (batch_size x weights)
return grad_list
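
Shape-wise this undoes flatten_jacobian_list once the class dimension has been reduced away. A plain-TensorFlow illustration (not part of the commit), continuing the shapes from the earlier example:

flat = tf.random.normal([8, 22])          # batch x all flattened weights
parts = tf.split(flat, [15, 7], axis=1)   # per-layer flattened gradients
grads = [tf.reshape(parts[0], [8, 5, 3]),
         tf.reshape(parts[1], [8, 7])]    # batch x original weight shapes
# The method itself uses tf_shell.split / tf_shell.reshape, so the same
# transformation can also be applied to SHELL tensors.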

def flatten_and_pad_grad_list(self, grads_list, slot_size):
"""Takes as input a list of tensors and flattens them into a single
tensor. The input is expected to be of shape:
layers list x (weights)
where weights may be any shape. The output is a tensor of shape:
slot_size x remaining flattened weights
which is the input weights flattened, concatenated, and padded out to
make the output shape non-ragged.
"""
if len(grads_list) == 0:
raise ValueError("No gradients found")

grad_shapes = [g.shape for g in grads_list]
flattened_grad_shapes = [s.num_elements() for s in grad_shapes]
total_grad_size = sum(flattened_grad_shapes)

flat_grad_list = [tf.reshape(g, [-1]) for g in grads_list]
# ^ layers list x (flattened weights)

flat_grads = tf.concat(flat_grad_list, axis=0)
# ^ all flattened weights

pad_len = slot_size - tf.math.floormod(
tf.cast(total_grad_size, dtype=tf.int64), slot_size
)
padded_flat_grads = tf.concat([flat_grads, tf.zeros(pad_len)], axis=0)
out = tf.reshape(padded_flat_grads, [slot_size, -1])

return out, grad_shapes, flattened_grad_shapes, total_grad_size
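
A worked example (not from the commit) of the padding arithmetic, again with a hypothetical model instance:

grads = [tf.random.normal([5, 3]), tf.random.normal([7])]  # 22 weights total
flat, shapes, flat_shapes, total = model.flatten_and_pad_grad_list(grads, 8)
# pad_len = 8 - (22 mod 8) = 2, so flat has shape [8, 3] and total == 22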

def unflatten_and_unpad_grad(
self, flat_grads, grad_shapes, flattened_grad_shapes, total_grad_size
):
"""Takes as input a flattened and padded gradient tensor and unflattens
it into a list of tensors. This undoes the flattening and padding
introduced by flatten_and_pad_grad_list(). The input is expected to be
of shape:
slot_size x remaining flattened weights
The output is a list of tensors of shape:
layers list x (weights)
"""

# First reshape to a flat tensor.
flat_grads = tf.reshape(flat_grads, [-1])

# Remove the padding.
flat_grads = flat_grads[:total_grad_size]

# Split the flat tensor into the original shapes.
grads_list = tf_shell.split(flat_grads, flattened_grad_shapes, axis=0)

# Reshape to the original shapes.
grads_list = [tf.reshape(g, s) for g, s in zip(grads_list, grad_shapes)]
return grads_list
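
And the corresponding round trip (not part of the commit), continuing the example above:

recovered = model.unflatten_and_unpad_grad(flat, shapes, flat_shapes, total)
# recovered[i] matches grads[i] in shape and value; in train_step this is
# applied after decrypting with the noise secret key.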