Reduce memory usage when computing jacobian and two norm.
james-choncholas committed Nov 5, 2024
1 parent f86add2 commit 84ab0f3
Showing 3 changed files with 100 additions and 158 deletions.
64 changes: 24 additions & 40 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -31,43 +31,26 @@ def __init__(

if len(self.layers) > 0:
self.layers[0].is_first_layer = True
# Do not set the derivative of the activation function for the last
# layer in the model. The derivative of the categorical crossentropy
# loss function times the derivative of a softmax is just y_pred -
# labels (which is much easier to compute than each of them
# individually). So instead just let the loss function derivative
# incorporate y_pred - labels and let the derivative of this last
# layer's activation be a no-op.
# Do not set the activation function for the last layer in the
# model. The derivative of the categorical crossentropy loss
# function times the derivative of a softmax is just predictions - labels
# (which is much easier to compute than each of them individually).
# So instead just let the loss function derivative incorporate
# predictions - labels and let the derivative of this last layer's
# activation be a no-op.
self.layers[-1].activation = None
self.layers[-1].activation_deriv = None
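
For reference, the identity this comment relies on can be checked numerically with a small standalone TensorFlow snippet (illustrative only, not part of this commit; shapes and seed are arbitrary): the gradient of categorical crossentropy composed with softmax, taken with respect to the logits, equals softmax(logits) - labels.

import tensorflow as tf

tf.random.set_seed(0)
logits = tf.Variable(tf.random.normal([4, 3]))  # batch of 4 examples, 3 classes
labels = tf.one_hot([0, 2, 1, 2], depth=3)      # one-hot targets

with tf.GradientTape() as tape:
    probs = tf.nn.softmax(logits)
    loss = tf.reduce_sum(-labels * tf.math.log(probs))  # summed categorical CE

autodiff_grad = tape.gradient(loss, logits)
closed_form = tf.nn.softmax(logits) - labels            # predictions - labels

# The two gradients agree, which is why the last layer's activation (and its
# derivative) can be dropped and folded into the loss derivative instead.
tf.debugging.assert_near(autodiff_grad, closed_form, atol=1e-5)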

def call(self, features, training=False):
out = features
def call(self, features, training=False, with_softmax=True):
predictions = features
for l in self.layers:
out = l(out, training=training)
return out
predictions = l(predictions, training=training)

def compute_max_two_norm_and_pred(self, features, skip_two_norm):
with tf.GradientTape(
persistent=tf.executing_eagerly() or self.jacobian_pfor
) as tape:
y_pred = self(features, training=True) # forward pass

if not skip_two_norm:
grads = tape.jacobian(
y_pred,
self.trainable_variables,
# unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
parallel_iterations=self.jacobian_pfor_iterations,
experimental_use_pfor=self.jacobian_pfor,
)
# ^ layers list x (batch size x num output classes x weights) matrix

all_grads, _, _, _ = self.flatten_jacobian_list(grads)
max_two_norm = self.flat_jacobian_two_norm(all_grads)
else:
max_two_norm = None

return y_pred, max_two_norm
if not with_softmax:
return predictions
# Perform the last layer activation since it is removed for training
# purposes.
return tf.nn.softmax(predictions)

def shell_train_step(self, features, labels):
with tf.device(self.labels_party_dev):
@@ -85,14 +68,15 @@ def shell_train_step(self, features, labels):

with tf.device(self.jacobian_device):
features = tf.identity(features) # copy to GPU if needed
# Forward pass in plaintext.
y_pred, max_two_norm = self.compute_max_two_norm_and_pred(
features, self.disable_noise
predictions, jacobians = self.predict_and_jacobian(
features,
skip_jacobian=self.disable_noise, # Jacobian only needed for noise.
)
max_two_norm = self.jacobian_max_two_norm(jacobians)

with tf.device(self.features_party_dev):
# Backward pass.
dx = self.loss_fn.grad(enc_y, y_pred)
dx = self.loss_fn.grad(enc_y, predictions)
dJ_dw = [] # Derivatives of the loss with respect to the weights.
dJ_dx = [dx] # Derivatives of the loss with respect to the inputs.
for l in reversed(self.layers):
@@ -252,17 +236,17 @@ def rebalance(x, s):
for metric in self.metrics:
if metric.name == "loss":
if self.disable_encryption:
loss = self.loss_fn(labels, y_pred)
loss = self.loss_fn(labels, predictions)
metric.update_state(loss)
else:
# Loss is unknown when encrypted.
metric.update_state(0.0)
else:
if self.disable_encryption:
metric.update_state(labels, y_pred)
metric.update_state(labels, predictions)
else:
# Other metrics are unknown when encrypted.
zeros = tf.broadcast_to(0, tf.shape(y_pred))
zeros = tf.broadcast_to(0, tf.shape(predictions))
metric.update_state(zeros, zeros)

metric_results = {m.name: m.result() for m in self.metrics}
110 changes: 43 additions & 67 deletions tf_shell_ml/model_base.py
@@ -299,75 +299,51 @@ def fit(
callback_list.on_train_end(logs)
return self.history

def flatten_jacobian_list(self, grads):
"""Takes as input a jacobian and flattens into a single tensor. The
jacobian is expected to be of the form:
layers list x (batch size x num output classes x weights)
where weights may be any shape. The output is a tensor of shape:
batch_size x num output classes x all flattened weights
where all flattened weights include weights from all the layers.
"""
if len(grads) == 0:
raise ValueError("No gradients found")
def predict_and_jacobian(self, features, skip_jacobian=False):
with tf.GradientTape(
persistent=tf.executing_eagerly() or self.jacobian_pfor
) as tape:
predictions = self(features, training=True, with_softmax=False)

if skip_jacobian:
jacobians = []
else:
jacobians = tape.jacobian(
predictions,
self.trainable_variables,
# unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
parallel_iterations=self.jacobian_pfor_iterations,
experimental_use_pfor=self.jacobian_pfor,
)
# ^ layers list x (batch size x num output classes x weights) matrix
# dy_pred_j/dW_sample_class

# Get the shapes from TensorFlow's tensors, not SHELL's context for when
# the batch size != slotting dim or not using encryption.
slot_size = tf.shape(grads[0])[0]
num_output_classes = tf.shape(grads[0])[1]
grad_shapes = [g.shape[2:] for g in grads]
flattened_grad_shapes = [s.num_elements() for s in grad_shapes]
# Compute the last layer's activation manually since we skipped it above.
predictions = tf.nn.softmax(predictions)

flat_grads = [
tf.reshape(g, [slot_size, num_output_classes, s])
for g, s in zip(grads, flattened_grad_shapes)
]
# ^ layers list (batch_size x num output classes x flattened layer weights)

all_grads = tf.concat(flat_grads, axis=2)
# ^ batch_size x num output classes x all flattened weights

return all_grads, slot_size, grad_shapes, flattened_grad_shapes

def flat_jacobian_two_norm(self, flat_jacobian):
"""Computes the maximum L2 norm of a flattened jacobian. The input is
expected to be of shape:
batch_size x num output classes x all flattened weights
where all flattened weights includes weights from all the layers."""
# The DP sensitivity of backprop when revealing the sum of gradients
# across a batch, is the maximum L2 norm of the gradient over every
# example in the batch, and over every class.
two_norms = tf.map_fn(lambda x: tf.norm(x, axis=0), flat_jacobian)
# ^ batch_size x num output classes
max_two_norm = tf.reduce_max(two_norms)
# ^ scalar
return max_two_norm

def unflatten_batch_grad_list(
self, flat_grads, slot_size, grad_shapes, flattened_grad_shapes
):
"""Takes as input a flattened gradient tensor and unflattens it into a
list of tensors. This is useful to undo the flattening performed by
flat_jacobian_list() after the output class dimension has been reduced.
The input is expected to be of shape:
batch_size x all flattened weights
where all flattened weights includes weights from all the layers. The
output is a list of tensors of shape:
layers list x (batch size x weights)
"""
# Split to recover the gradients by layer.
grad_list = tf_shell.split(flat_grads, flattened_grad_shapes, axis=1)
# ^ layers list (batch_size x flattened weights)

# Unflatten the gradients to the original layer shape.
grad_list = [
tf_shell.reshape(
g,
tf.concat([[slot_size], tf.cast(s, dtype=tf.int64)], axis=0),
)
for g, s in zip(grad_list, grad_shapes)
]
# ^ layers list (batch_size x weights)
return grad_list
return predictions, jacobians
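
To make the shape convention concrete, here is a minimal standalone sketch (not from this repository; the toy model and sizes are made up) of what tape.jacobian returns for a model with two dense layers: one entry per trainable variable, each shaped batch size x num output classes x the variable's own shape, which is exactly what jacobian_max_two_norm consumes below.

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(5, activation="relu"),
    tf.keras.layers.Dense(3),  # 3 output classes, logits only (no softmax)
])
features = tf.random.normal([4, 8])  # batch of 4

with tf.GradientTape() as tape:
    logits = model(features, training=True)

jacobians = tape.jacobian(logits, model.trainable_variables)
for var, jac in zip(model.trainable_variables, jacobians):
    print(var.shape, "->", jac.shape)
# (8, 5) -> (4, 3, 8, 5)
# (5,)   -> (4, 3, 5)
# (5, 3) -> (4, 3, 5, 3)
# (3,)   -> (4, 3, 3)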

def jacobian_max_two_norm(self, jacobians):
"""Takes the output of the jacobian computation and computes the max two
norm of the weights over all examples in the batch and all output
classes. Do this layer-wise to reduce memory usage."""
if len(jacobians) == 0:
return tf.constant(0.0, dtype=tf.keras.backend.floatx())

batch_size = jacobians[0].shape[0]
num_output_classes = jacobians[0].shape[1]
sum_of_squares = tf.zeros(
[batch_size, num_output_classes], dtype=tf.keras.backend.floatx()
)

for j in jacobians:
# Ignore the batch size and num output classes dimensions and
# recover just the number of dimensions in the weights.
num_weight_dims = len(j.shape) - 2
reduce_sum_dims = range(2, 2 + num_weight_dims)
sum_of_squares += tf.reduce_sum(j * j, axis=reduce_sum_dims)

return tf.sqrt(tf.reduce_max(sum_of_squares))
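
The memory saving comes from never materializing the batch_size x num output classes x all flattened weights tensor that the removed flatten_jacobian_list/flat_jacobian_two_norm pair built. A short standalone check in plain TensorFlow (illustrative shapes only) that the layer-wise accumulation gives the same maximum two-norm as flattening first:

import tensorflow as tf

batch, classes = 4, 3
jacobians = [
    tf.random.normal([batch, classes, 8, 5]),  # e.g. a kernel's jacobian
    tf.random.normal([batch, classes, 5]),     # e.g. a bias's jacobian
]

# Flatten-first approach: one large [batch, classes, all_weights] tensor.
flat = tf.concat([tf.reshape(j, [batch, classes, -1]) for j in jacobians], axis=2)
old_max = tf.reduce_max(tf.norm(flat, axis=2))

# Layer-wise approach: accumulate per-(example, class) sums of squares.
sum_sq = tf.zeros([batch, classes])
for j in jacobians:
    sum_sq += tf.reduce_sum(j * j, axis=list(range(2, len(j.shape))))
new_max = tf.sqrt(tf.reduce_max(sum_sq))

tf.debugging.assert_near(old_max, new_max)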

def flatten_and_pad_grad_list(self, grads_list, slot_size):
"""Takes as input a list of tensors and flattens them into a single
84 changes: 33 additions & 51 deletions tf_shell_ml/postscale_sequential_model.py
@@ -56,59 +56,41 @@ def shell_train_step(self, features, labels):

with tf.device(self.jacobian_device):
features = tf.identity(features) # copy to GPU if needed

# self.layers[-1].activation = tf.keras.activations.linear
with tf.GradientTape(
persistent=tf.executing_eagerly() or self.jacobian_pfor
) as tape:
y_pred = self(features, training=True, with_softmax=False)

grads = tape.jacobian(
y_pred,
self.trainable_variables,
# unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
parallel_iterations=self.jacobian_pfor_iterations,
experimental_use_pfor=self.jacobian_pfor,
)
# ^ layers list x (batch size x num output classes x weights) matrix
# dy_pred_j/dW_sample_class

# Compute the activation manually.
y_pred = tf.nn.softmax(y_pred)
predictions, jacobians = self.predict_and_jacobian(features)
max_two_norm = self.jacobian_max_two_norm(jacobians)

with tf.device(self.features_party_dev):
# Compute prediction - labels (where labels may be encrypted).
scalars = enc_y.__rsub__(y_pred) # dJ/dy_pred
scalars = enc_y.__rsub__(predictions) # dJ/dprediction
# ^ batch_size x num output classes.

# Expand the last dim so that the subsequent multiplications are
# broadcasted.
scalars = tf_shell.expand_dims(scalars, axis=-1)
# ^ batch_size x num output classes x 1

# Flatten and remember the original shape of the gradient in order
# to unpack them after the multiplication so they can be applied to
# the model.
grads, slot_size, grad_shapes, flattened_grad_shapes = (
self.flatten_jacobian_list(grads)
)
# ^ batch_size x num output classes x all flattened weights

max_two_norm = self.flat_jacobian_two_norm(grads)

# Scale the gradients.
grads = scalars * grads
# ^ batch_size x num output classes x all flattened weights

# Sum over the output classes.
grads = tf_shell.reduce_sum(grads, axis=1)
# ^ batch_size x all flattened weights

# Recover the original shapes of the gradients.
grads = self.unflatten_batch_grad_list(
grads, slot_size, grad_shapes, flattened_grad_shapes
)
# ^ layers list (batch_size x weights)
# Scale each gradient. Since 'scalars' may be a vector of
# ciphertexts, this requires multiplying the plaintext gradient for
# each specific layer (2d) by the ciphertext scalar.
grads = []
for j in jacobians:
# Ignore the batch size and num output classes dimensions and
# recover just the number of dimensions in the weights.
num_weight_dims = len(j.shape) - 2

# Make the scalars the same shape as the gradients so the
# multiplication can be broadcast. Doing this inside the loop
# is okay; TensorFlow will reuse the same expanded tensor when
# the dimensions match across iterations.
scalars_exp = scalars
for _ in range(num_weight_dims):
scalars_exp = tf_shell.expand_dims(scalars_exp, axis=-1)

# Scale the jacobian.
scaled_grad = scalars_exp * j
# ^ batch_size x num output classes x weights

# Sum over the output classes. At this point, this is a gradient
# and no longer a jacobian.
scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1)
# ^ batch_size x weights

grads.append(scaled_grad)
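
As a plaintext stand-in for the tf_shell ops above (illustrative only; in the real model the scalars are ciphertexts), the per-layer scaling amounts to broadcasting the batch_size x num output classes scalars against each jacobian and summing out the class dimension:

import tensorflow as tf

batch, classes = 4, 3
scalars = tf.random.normal([batch, classes])    # stands in for predictions - labels
jacobians = [
    tf.random.normal([batch, classes, 8, 5]),   # kernel jacobian
    tf.random.normal([batch, classes, 5]),      # bias jacobian
]

grads = []
for j in jacobians:
    s = scalars
    for _ in range(len(j.shape) - 2):           # match the jacobian's rank
        s = tf.expand_dims(s, axis=-1)
    scaled = s * j                              # batch x classes x weights
    grads.append(tf.reduce_sum(scaled, axis=1)) # sum over classes -> batch x weights

for g in grads:
    print(g.shape)  # (4, 8, 5) then (4, 5)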

# Check if the post-scaled gradients overflowed.
if not self.disable_encryption and self.check_overflow_INSECURE:
@@ -250,17 +232,17 @@ def rebalance(x, s):
for metric in self.metrics:
if metric.name == "loss":
if self.disable_encryption:
loss = self.loss_fn(labels, y_pred)
loss = self.loss_fn(labels, predictions)
metric.update_state(loss)
else:
# Loss is unknown when encrypted.
metric.update_state(0.0)
else:
if self.disable_encryption:
metric.update_state(labels, y_pred)
metric.update_state(labels, predictions)
else:
# Other metrics are unknown when encrypted.
zeros = tf.broadcast_to(0, tf.shape(y_pred))
zeros = tf.broadcast_to(0, tf.shape(predictions))
metric.update_state(zeros, zeros)

metric_results = {m.name: m.result() for m in self.metrics}
