diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py
index 6cbb9e3..a505208 100644
--- a/tf_shell_ml/dpsgd_sequential_model.py
+++ b/tf_shell_ml/dpsgd_sequential_model.py
@@ -31,43 +31,26 @@ def __init__(
         if len(self.layers) > 0:
             self.layers[0].is_first_layer = True
 
-        # Do not set the derivative of the activation function for the last
-        # layer in the model. The derivative of the categorical crossentropy
-        # loss function times the derivative of a softmax is just y_pred -
-        # labels (which is much easier to compute than each of them
-        # individually). So instead just let the loss function derivative
-        # incorporate y_pred - labels and let the derivative of this last
-        # layer's activation be a no-op.
+        # Do not set the activation function for the last layer in the
+        # model. The derivative of the categorical crossentropy loss
+        # function times the derivative of a softmax is just predictions - labels
+        # (which is much easier to compute than each of them individually).
+        # So instead just let the loss function derivative incorporate
+        # predictions - labels and let the derivative of this last layer's
+        # activation be a no-op.
+        self.layers[-1].activation = None
         self.layers[-1].activation_deriv = None
 
-    def call(self, features, training=False):
-        out = features
+    def call(self, features, training=False, with_softmax=True):
+        predictions = features
         for l in self.layers:
-            out = l(out, training=training)
-        return out
+            predictions = l(predictions, training=training)
 
-    def compute_max_two_norm_and_pred(self, features, skip_two_norm):
-        with tf.GradientTape(
-            persistent=tf.executing_eagerly() or self.jacobian_pfor
-        ) as tape:
-            y_pred = self(features, training=True)  # forward pass
-
-        if not skip_two_norm:
-            grads = tape.jacobian(
-                y_pred,
-                self.trainable_variables,
-                # unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
-                parallel_iterations=self.jacobian_pfor_iterations,
-                experimental_use_pfor=self.jacobian_pfor,
-            )
-            # ^ layers list x (batch size x num output classes x weights) matrix
-
-            all_grads, _, _, _ = self.flatten_jacobian_list(grads)
-            max_two_norm = self.flat_jacobian_two_norm(all_grads)
-        else:
-            max_two_norm = None
-
-        return y_pred, max_two_norm
+        if not with_softmax:
+            return predictions
+        # Perform the last layer activation since it is removed for training
+        # purposes.
+        return tf.nn.softmax(predictions)
 
     def shell_train_step(self, features, labels):
         with tf.device(self.labels_party_dev):
@@ -85,14 +68,15 @@ def shell_train_step(self, features, labels):
         with tf.device(self.jacobian_device):
             features = tf.identity(features)  # copy to GPU if needed
 
-            # Forward pass in plaintext.
-            y_pred, max_two_norm = self.compute_max_two_norm_and_pred(
-                features, self.disable_noise
+            predictions, jacobians = self.predict_and_jacobian(
+                features,
+                skip_jacobian=self.disable_noise,  # Jacobian only needed for noise.
             )
+            max_two_norm = self.jacobian_max_two_norm(jacobians)
 
         with tf.device(self.features_party_dev):
             # Backward pass.
-            dx = self.loss_fn.grad(enc_y, y_pred)
+            dx = self.loss_fn.grad(enc_y, predictions)
             dJ_dw = []  # Derivatives of the loss with respect to the weights.
             dJ_dx = [dx]  # Derivatives of the loss with respect to the inputs.
             for l in reversed(self.layers):
@@ -252,17 +236,17 @@ def rebalance(x, s):
             for metric in self.metrics:
                 if metric.name == "loss":
                     if self.disable_encryption:
-                        loss = self.loss_fn(labels, y_pred)
+                        loss = self.loss_fn(labels, predictions)
                         metric.update_state(loss)
                     else:
                         # Loss is unknown when encrypted.
                         metric.update_state(0.0)
                 else:
                     if self.disable_encryption:
-                        metric.update_state(labels, y_pred)
+                        metric.update_state(labels, predictions)
                     else:
                         # Other metrics are uknown when encrypted.
-                        zeros = tf.broadcast_to(0, tf.shape(y_pred))
+                        zeros = tf.broadcast_to(0, tf.shape(predictions))
                         metric.update_state(zeros, zeros)
 
             metric_results = {m.name: m.result() for m in self.metrics}
diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py
index f055752..1c2c78e 100644
--- a/tf_shell_ml/model_base.py
+++ b/tf_shell_ml/model_base.py
@@ -299,75 +299,51 @@ def fit(
         callback_list.on_train_end(logs)
         return self.history
 
-    def flatten_jacobian_list(self, grads):
-        """Takes as input a jacobian and flattens into a single tensor. The
-        jacobian is expected to be of the form:
-            layers list x (batch size x num output classes x weights)
-        where weights may be any shape. The output is a tensor of shape:
-            batch_size x num output classes x all flattened weights
-        where all flattened weights include weights from all the layers.
-        """
-        if len(grads) == 0:
-            raise ValueError("No gradients found")
+    def predict_and_jacobian(self, features, skip_jacobian=False):
+        with tf.GradientTape(
+            persistent=tf.executing_eagerly() or self.jacobian_pfor
+        ) as tape:
+            predictions = self(features, training=True, with_softmax=False)
+
+        if skip_jacobian:
+            jacobians = []
+        else:
+            jacobians = tape.jacobian(
+                predictions,
+                self.trainable_variables,
+                # unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
+                parallel_iterations=self.jacobian_pfor_iterations,
+                experimental_use_pfor=self.jacobian_pfor,
+            )
+            # ^ layers list x (batch size x num output classes x weights) matrix
+            # dy_pred_j/dW_sample_class
 
-        # Get the shapes from TensorFlow's tensors, not SHELL's context for when
-        # the batch size != slotting dim or not using encryption.
-        slot_size = tf.shape(grads[0])[0]
-        num_output_classes = tf.shape(grads[0])[1]
-        grad_shapes = [g.shape[2:] for g in grads]
-        flattened_grad_shapes = [s.num_elements() for s in grad_shapes]
+        # Compute the last layer's activation manually since we skipped it above.
+        predictions = tf.nn.softmax(predictions)
 
-        flat_grads = [
-            tf.reshape(g, [slot_size, num_output_classes, s])
-            for g, s in zip(grads, flattened_grad_shapes)
-        ]
-        # ^ layers list (batch_size x num output classes x flattened layer weights)
-
-        all_grads = tf.concat(flat_grads, axis=2)
-        # ^ batch_size x num output classes x all flattened weights
-
-        return all_grads, slot_size, grad_shapes, flattened_grad_shapes
-
-    def flat_jacobian_two_norm(self, flat_jacobian):
-        """Computes the maximum L2 norm of a flattened jacobian. The input is
-        expected to be of shape:
-            batch_size x num output classes x all flattened weights
-        where all flattened weights includes weights from all the layers."""
-        # The DP sensitivity of backprop when revealing the sum of gradients
-        # across a batch, is the maximum L2 norm of the gradient over every
-        # example in the batch, and over every class.
-        two_norms = tf.map_fn(lambda x: tf.norm(x, axis=0), flat_jacobian)
-        # ^ batch_size x num output classes
-        max_two_norm = tf.reduce_max(two_norms)
-        # ^ scalar
-        return max_two_norm
-
-    def unflatten_batch_grad_list(
-        self, flat_grads, slot_size, grad_shapes, flattened_grad_shapes
-    ):
-        """Takes as input a flattened gradient tensor and unflattens it into a
-        list of tensors. This is useful to undo the flattening performed by
-        flat_jacobian_list() after the output class dimension has been reduced.
-        The input is expected to be of shape:
-            batch_size x all flattened weights
-        where all flattened weights includes weights from all the layers. The
-        output is a list of tensors of shape:
-            layers list x (batch size x weights)
-        """
-        # Split to recover the gradients by layer.
-        grad_list = tf_shell.split(flat_grads, flattened_grad_shapes, axis=1)
-        # ^ layers list (batch_size x flattened weights)
-
-        # Unflatten the gradients to the original layer shape.
-        grad_list = [
-            tf_shell.reshape(
-                g,
-                tf.concat([[slot_size], tf.cast(s, dtype=tf.int64)], axis=0),
-            )
-            for g, s in zip(grad_list, grad_shapes)
-        ]
-        # ^ layers list (batch_size x weights)
-        return grad_list
+        return predictions, jacobians
+
+    def jacobian_max_two_norm(self, jacobians):
+        """Takes the output of the jacobian computation and computes the max two
+        norm of the weights over all examples in the batch and all output
+        classes. Do this layer-wise to reduce memory usage."""
+        if len(jacobians) == 0:
+            return tf.constant(0.0, dtype=tf.keras.backend.floatx())
+
+        batch_size = jacobians[0].shape[0]
+        num_output_classes = jacobians[0].shape[1]
+        sum_of_squares = tf.zeros(
+            [batch_size, num_output_classes], dtype=tf.keras.backend.floatx()
+        )
+
+        for j in jacobians:
+            # Ignore the batch size and num output classes dimensions and
+            # recover just the number of dimensions in the weights.
+            num_weight_dims = len(j.shape) - 2
+            reduce_sum_dims = range(2, 2 + num_weight_dims)
+            sum_of_squares += tf.reduce_sum(j * j, axis=reduce_sum_dims)
+
+        return tf.sqrt(tf.reduce_max(sum_of_squares))
 
     def flatten_and_pad_grad_list(self, grads_list, slot_size):
         """Takes as input a list of tensors and flattens them into a single
diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py
index f8e3f0a..67f5743 100644
--- a/tf_shell_ml/postscale_sequential_model.py
+++ b/tf_shell_ml/postscale_sequential_model.py
@@ -56,59 +56,41 @@ def shell_train_step(self, features, labels):
 
         with tf.device(self.jacobian_device):
             features = tf.identity(features)  # copy to GPU if needed
-
-            # self.layers[-1].activation = tf.keras.activations.linear
-            with tf.GradientTape(
-                persistent=tf.executing_eagerly() or self.jacobian_pfor
-            ) as tape:
-                y_pred = self(features, training=True, with_softmax=False)
-
-            grads = tape.jacobian(
-                y_pred,
-                self.trainable_variables,
-                # unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
-                parallel_iterations=self.jacobian_pfor_iterations,
-                experimental_use_pfor=self.jacobian_pfor,
-            )
-            # ^ layers list x (batch size x num output classes x weights) matrix
-            # dy_pred_j/dW_sample_class
-
-            # Compute the activation manually.
-            y_pred = tf.nn.softmax(y_pred)
+            predictions, jacobians = self.predict_and_jacobian(features)
+            max_two_norm = self.jacobian_max_two_norm(jacobians)
 
         with tf.device(self.features_party_dev):
             # Compute prediction - labels (where labels may be encrypted).
-            scalars = enc_y.__rsub__(y_pred)  # dJ/dy_pred
+            scalars = enc_y.__rsub__(predictions)  # dJ/dprediction
             # ^ batch_size x num output classes.
 
-            # Expand the last dim so that the subsequent multiplications are
-            # broadcasted.
-            scalars = tf_shell.expand_dims(scalars, axis=-1)
-            # ^ batch_size x num output classes x 1
-
-            # Flatten and remember the original shape of the gradient in order
-            # to unpack them after the multiplication so they can be applied to
-            # the model.
-            grads, slot_size, grad_shapes, flattened_grad_shapes = (
-                self.flatten_jacobian_list(grads)
-            )
-            # ^ batch_size x num output classes x all flattened weights
-
-            max_two_norm = self.flat_jacobian_two_norm(grads)
-
-            # Scale the gradients.
-            grads = scalars * grads
-            # ^ batch_size x num output classes x all flattened weights
-
-            # Sum over the output classes.
-            grads = tf_shell.reduce_sum(grads, axis=1)
-            # ^ batch_size x all flattened weights
-
-            # Recover the original shapes of the gradients.
-            grads = self.unflatten_batch_grad_list(
-                grads, slot_size, grad_shapes, flattened_grad_shapes
-            )
-            # ^ layers list (batch_size x weights)
+            # Scale each gradient. Since 'scalars' may be a vector of
+            # ciphertexts, this requires multiplying the plaintext gradient for
+            # the specific layer (2d) by the ciphertext (scalar).
+            grads = []
+            for j in jacobians:
+                # Ignore the batch size and num output classes dimensions and
+                # recover just the number of dimensions in the weights.
+                num_weight_dims = len(j.shape) - 2
+
+                # Make the scalars the same shape as the gradients so the
+                # multiplication can be broadcasted. Doing this inside the loop
+                # is okay, TensorFlow will reuse the same expanded tensor if
+                # their dimensions match across iterations.
+                scalars_exp = scalars
+                for _ in range(num_weight_dims):
+                    scalars_exp = tf_shell.expand_dims(scalars_exp, axis=-1)
+
+                # Scale the jacobian.
+                scaled_grad = scalars_exp * j
+                # ^ batch_size x num output classes x weights
+
+                # Sum over the output classes. At this point, this is a gradient
+                # and no longer a jacobian.
+                scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1)
+                # ^ batch_size x weights
+
+                grads.append(scaled_grad)
 
             # Check if the post-scaled gradients overflowed.
             if not self.disable_encryption and self.check_overflow_INSECURE:
@@ -250,17 +232,17 @@ def rebalance(x, s):
             for metric in self.metrics:
                 if metric.name == "loss":
                     if self.disable_encryption:
-                        loss = self.loss_fn(labels, y_pred)
+                        loss = self.loss_fn(labels, predictions)
                         metric.update_state(loss)
                     else:
                         # Loss is unknown when encrypted.
                         metric.update_state(0.0)
                 else:
                     if self.disable_encryption:
-                        metric.update_state(labels, y_pred)
+                        metric.update_state(labels, predictions)
                     else:
                         # Other metrics are uknown when encrypted.
-                        zeros = tf.broadcast_to(0, tf.shape(y_pred))
+                        zeros = tf.broadcast_to(0, tf.shape(predictions))
                         metric.update_state(zeros, zeros)
 
             metric_results = {m.name: m.result() for m in self.metrics}
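
Note (not part of the patch): the dpsgd change relies on the identity that the
derivative of categorical crossentropy composed with softmax, taken with
respect to the logits, is simply predictions - labels. A minimal plain-TensorFlow
sketch, assuming one-hot labels, that checks this numerically:

import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.5]])
labels = tf.constant([[0.0, 1.0, 0.0]])  # one-hot

with tf.GradientTape() as tape:
    tape.watch(logits)
    loss = tf.keras.losses.categorical_crossentropy(labels, tf.nn.softmax(logits))

# d(CCE o softmax)/dlogits == softmax(logits) - labels
grad = tape.gradient(loss, logits)
tf.debugging.assert_near(grad, tf.nn.softmax(logits) - labels, atol=1e-5)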
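
Note (not part of the patch): the comments above use the shape convention
"layers list x (batch size x num output classes x weights)". A minimal sketch
with a plain Keras Dense layer (not the tf_shell_ml layers) showing that
tape.jacobian produces exactly that layout:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
x = tf.random.normal([4, 5])  # batch of 4, feature dim 5

with tf.GradientTape() as tape:
    predictions = model(x)  # builds kernel [5, 3] and bias [3]

jacobians = tape.jacobian(predictions, model.trainable_variables)
# One entry per trainable variable:
#   kernel -> [4, 3, 5, 3]  (batch x classes x weight shape)
#   bias   -> [4, 3, 3]
print([j.shape for j in jacobians])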
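
Note (not part of the patch): jacobian_max_two_norm accumulates per-example,
per-class sums of squares layer by layer instead of materializing a flattened
jacobian. A minimal sketch on a toy jacobian list (made-up weight shapes),
checking that this matches flattening all weights per (example, class) and
taking the largest two norm:

import tensorflow as tf

batch, classes = 4, 3
jacobians = [
    tf.random.normal([batch, classes, 5, 2]),  # e.g. a kernel's jacobian
    tf.random.normal([batch, classes, 2]),     # e.g. a bias's jacobian
]

# Layer-wise accumulation, as in jacobian_max_two_norm.
sum_of_squares = tf.zeros([batch, classes])
for j in jacobians:
    sum_of_squares += tf.reduce_sum(j * j, axis=list(range(2, len(j.shape))))
layerwise = tf.sqrt(tf.reduce_max(sum_of_squares))

# Reference: concatenate all flattened weights per (example, class), take max norm.
flat = tf.concat([tf.reshape(j, [batch, classes, -1]) for j in jacobians], axis=2)
reference = tf.reduce_max(tf.norm(flat, axis=2))

tf.debugging.assert_near(layerwise, reference, atol=1e-4)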
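
Note (not part of the patch): the postscale loop broadcasts the [batch, classes]
scalars against each layer's [batch, classes, ...weights] jacobian by appending
singleton dims, then sums over the class axis. A plaintext sketch using tf in
place of tf_shell (so no ciphertexts), checked against a direct einsum
contraction:

import tensorflow as tf

batch, classes = 4, 3
scalars = tf.random.normal([batch, classes])  # stands in for predictions - labels
j = tf.random.normal([batch, classes, 5, 2])  # one layer's jacobian

# Append one singleton dim per weight dimension so the multiply broadcasts.
scalars_exp = scalars
for _ in range(len(j.shape) - 2):
    scalars_exp = tf.expand_dims(scalars_exp, axis=-1)
# ^ batch x classes x 1 x 1

grad = tf.reduce_sum(scalars_exp * j, axis=1)  # batch x 5 x 2
tf.debugging.assert_near(grad, tf.einsum("bc,bcij->bij", scalars, j), atol=1e-5)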