
Commit

Sequential models support accelerators to compute jacobian, and split input datasets.
james-choncholas committed Nov 4, 2024
1 parent e6fb636 commit 2237f5e
Showing 10 changed files with 521 additions and 187 deletions.
41 changes: 22 additions & 19 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -33,17 +33,18 @@ def __init__(
self.layers[0].is_first_layer = True
# Do not set the derivative of the activation function for the last
# layer in the model. The derivative of the categorical crossentropy
# loss function times the derivative of a softmax is just y_pred - y
# (which is much easier to compute than each of them individually).
# So instead just let the loss function derivative incorporate
# y_pred - y and let the derivative of this last layer's activation
# be a no-op.
# loss function times the derivative of a softmax is just y_pred -
# labels (which is much easier to compute than each of them
# individually). So instead just let the loss function derivative
# incorporate y_pred - labels and let the derivative of this last
# layer's activation be a no-op.
self.layers[-1].activation_deriv = None
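The identity this comment relies on can be checked numerically with stock TensorFlow. The snippet below is illustrative only and is not part of the commit; it uses a single made-up example.

import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.5]])
labels = tf.constant([[0.0, 1.0, 0.0]])

with tf.GradientTape() as tape:
    tape.watch(logits)
    y_pred = tf.nn.softmax(logits)
    loss = tf.keras.losses.categorical_crossentropy(labels, y_pred)

# The gradient of crossentropy(softmax(logits)) with respect to the logits...
print(tape.gradient(loss, logits).numpy())
# ...matches y_pred - labels up to floating point error.
print((y_pred - labels).numpy())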

def call(self, x, training=False):
def call(self, features, training=False):
out = features
for l in self.layers:
x = l(x, training=training)
return x
out = l(out, training=training)
return out

def compute_max_two_norm_and_pred(self, features, skip_two_norm):
with tf.GradientTape(
@@ -68,26 +69,28 @@ def compute_max_two_norm_and_pred(self, features, skip_two_norm):

return y_pred, max_two_norm
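For reference, here is a minimal sketch of computing the predictions together with the largest per-example two-norm of the prediction jacobian for an ordinary Keras model. The function name and the exact norm definition are assumptions; the real implementation is in the collapsed lines above and also supports the pfor and device placement options added by this commit.

import tensorflow as tf

def predictions_and_max_jacobian_two_norm(model, features):
    with tf.GradientTape(persistent=True) as tape:
        y_pred = model(features, training=True)
    # For each trainable variable, the jacobian has shape
    # [batch, num_classes, *variable.shape].
    jacobians = tape.jacobian(y_pred, model.trainable_variables)
    squared_norms = tf.zeros_like(y_pred[:, 0])
    for jac in jacobians:
        flat = tf.reshape(jac, [tf.shape(jac)[0], -1])
        squared_norms += tf.reduce_sum(tf.square(flat), axis=1)
    return y_pred, tf.sqrt(tf.reduce_max(squared_norms))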

def shell_train_step(self, data):
x, y = data

def shell_train_step(self, features, labels):
with tf.device(self.labels_party_dev):
if self.disable_encryption:
enc_y = y
enc_y = labels
else:
backprop_context = self.backprop_context_fn()
backprop_secret_key = tf_shell.create_key64(
backprop_context, self.cache_path
)
# Encrypt the batch of secret labels y.
enc_y = tf_shell.to_encrypted(y, backprop_secret_key, backprop_context)
# Encrypt the batch of secret labels.
enc_y = tf_shell.to_encrypted(
labels, backprop_secret_key, backprop_context
)

with tf.device(self.features_party_dev):
with tf.device(self.jacobian_device):
features = tf.identity(features) # copy to GPU if needed
# Forward pass in plaintext.
y_pred, max_two_norm = self.compute_max_two_norm_and_pred(
x, self.disable_noise
features, self.disable_noise
)

with tf.device(self.features_party_dev):
# Backward pass.
dx = self.loss_fn.grad(enc_y, y_pred)
dJ_dw = [] # Derivatives of the loss with respect to the weights.
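The hand-off to self.jacobian_device above uses a standard TensorFlow pattern. A self-contained illustration follows; the device strings are examples, not the repository's defaults.

import tensorflow as tf

jacobian_device = "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"
features = tf.random.uniform([8, 4])

with tf.device(jacobian_device):
    features_on_device = tf.identity(features)  # copies to the device if needed
print(features_on_device.device)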
@@ -199,7 +202,7 @@ def shell_train_step(self, data):
# secret key.
flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)

if not self.disable_encryption and self.check_overflow_INSECURE:
if self.check_overflow_INSECURE:
nosie_scaling_factors = grads._scaling_factor
self.warn_on_overflow(
[flat_grads],
@@ -249,14 +252,14 @@ def rebalance(x, s):
for metric in self.metrics:
if metric.name == "loss":
if self.disable_encryption:
loss = self.loss_fn(y, y_pred)
loss = self.loss_fn(labels, y_pred)
metric.update_state(loss)
else:
# Loss is unknown when encrypted.
metric.update_state(0.0)
else:
if self.disable_encryption:
metric.update_state(y, y_pred)
metric.update_state(labels, y_pred)
else:
# Other metrics are unknown when encrypted.
zeros = tf.broadcast_to(0, tf.shape(y_pred))
222 changes: 178 additions & 44 deletions tf_shell_ml/model_base.py
@@ -15,8 +15,10 @@
# limitations under the License.
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.callbacks import CallbackList
import tf_shell
import tf_shell_ml
import time


class SequentialBase(keras.Sequential):
@@ -32,6 +34,7 @@ def __init__(
cache_path=None,
jacobian_pfor=False,
jacobian_pfor_iterations=None,
jacobian_device=None,
disable_encryption=False,
disable_masking=False,
disable_noise=False,
@@ -48,6 +51,9 @@ def __init__(
self.cache_path = cache_path
self.jacobian_pfor = jacobian_pfor
self.jacobian_pfor_iterations = jacobian_pfor_iterations
self.jacobian_device = (
features_party_dev if jacobian_device is None else jacobian_device
)
self.disable_encryption = disable_encryption
self.disable_masking = disable_masking
self.disable_noise = disable_noise
@@ -59,12 +65,6 @@ def __init__(
"WARNING: `jacobian_pfor` may be incompatible with `disable_encryption`."
)

def compile(self, shell_loss, **kwargs):
if not isinstance(shell_loss, tf_shell_ml.CategoricalCrossentropy):
raise ValueError(
"The model must be used with the tf-shell version of CategoricalCrossentropy loss function. Saw",
shell_loss,
)
if len(self.layers) > 0 and not (
self.layers[-1].activation is tf.keras.activations.softmax
or self.layers[-1].activation is tf.nn.softmax
@@ -74,54 +74,85 @@ def compile(self, shell_loss, **kwargs):
self.layers[-1].activation,
)

if shell_loss is None:
raise ValueError("shell_loss must be provided")
def compile(self, shell_loss, **kwargs):
if not isinstance(shell_loss, tf_shell_ml.CategoricalCrossentropy):
raise ValueError(
"The model must be used with the tf_shell_ml version of CategoricalCrossentropy loss function. Saw",
shell_loss,
)
self.loss_fn = shell_loss

super().compile(loss=tf.keras.losses.CategoricalCrossentropy(), **kwargs)
super().compile(
loss=tf.keras.losses.CategoricalCrossentropy(),
jit_compile=False, # Disable XLA, no CPU op for tf_shell_ml's TensorArrayV2.
**kwargs,
)
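Put together, the new jacobian_device argument and the compile checks above suggest usage along the following lines for the DP-SGD subclass. This is a schematic: the class name is inferred from dpsgd_sequential_model.py, the elided constructor arguments (layers, context functions, and so on) are not shown in this diff, and the device strings are examples.

import tensorflow as tf
import tf_shell_ml

model = tf_shell_ml.DpSgdSequential(  # class name assumed, not shown in this diff
    layers=[...],                     # tf_shell_ml layers, omitted here
    features_party_dev="/job:localhost/device:CPU:0",
    labels_party_dev="/job:localhost/device:CPU:0",
    jacobian_device="/job:localhost/device:GPU:0",  # new: offload the plaintext jacobian
)
model.compile(
    shell_loss=tf_shell_ml.CategoricalCrossentropy(),  # enforced by the isinstance check above
    optimizer=tf.keras.optimizers.Adam(0.01),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)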

def train_step(self, data):
metrics, num_slots = self.shell_train_step(data)
def train_step(self, features, labels):
metrics, num_slots = self.shell_train_step(features, labels)
return metrics

@tf.function
def train_step_tf_func(self, data):
return self.shell_train_step(data)
def train_step_with_keygen(self, features, labels):
return self.shell_train_step(features, labels)

def shell_train_step(self, data):
@tf.function
def train_step_tf_func(self, features, labels):
return self.shell_train_step(features, labels)

def shell_train_step(self, features, labels):
raise NotImplementedError() # Should be overloaded by the subclass.

# Prepare the dataset for training with encryption by setting the batch size
# to the same value as the encryption ring degree. Run the training loop once
# on dummy data to figure out the batch size.
def prep_dataset_for_model(self, train_dataset):
def prep_dataset_for_model(self, train_features, train_labels):
"""Prepare the dataset for training with encryption by setting the batch
size to the same value as the encryption ring degree. Run the training
loop once on dummy data to figure out the batch size.
"""
if self.disable_encryption:
self.batch_size = next(iter(train_features)).shape[0]
self.dataset_prepped = True
return train_dataset
return train_features, train_labels

# Run the training loop once on dummy data to figure out the batch size.
tf.config.run_functions_eagerly(False)
metrics, num_slots = self.train_step_tf_func(next(iter(train_dataset)))
# Use a separate tf.function to avoid caching the trace so keys and
# context are written to cache and read on next trace.
metrics, num_slots = self.train_step_with_keygen(
next(iter(train_features)), next(iter(train_labels))
)

self.batch_size = num_slots.numpy()

train_dataset = train_dataset.rebatch(num_slots.numpy(), drop_remainder=True)
with tf.device(self.features_party_dev):
train_features = train_features.rebatch(
num_slots.numpy(), drop_remainder=True
)
with tf.device(self.labels_party_dev):
train_labels = train_labels.rebatch(num_slots.numpy(), drop_remainder=True)

self.dataset_prepped = True
return train_dataset

# Prepare the dataset for training with encryption by setting the batch size
# to the same value as the encryption ring degree. It is faster than
# `prep_dataset_for_model` because it does not execute the graph, instead
# tracing and optimizing the graph and extracting the required parameters
# without actually executing the graph.
def fast_prep_dataset_for_model(self, train_dataset):
return train_features, train_labels
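The split-dataset input format introduced by this commit can be sketched as follows. The data, shapes, and batch size are made up, and model stands in for a compiled tf_shell_ml sequential model such as the one sketched earlier.

import tensorflow as tf

x = tf.random.uniform([1024, 16])
y = tf.one_hot(tf.random.uniform([1024], maxval=10, dtype=tf.int32), depth=10)

features_dataset = tf.data.Dataset.from_tensor_slices(x).batch(64)
labels_dataset = tf.data.Dataset.from_tensor_slices(y).batch(64)

# Rebatches both datasets to the encryption ring's slot count, or leaves them
# untouched when encryption is disabled.
features_dataset, labels_dataset = model.prep_dataset_for_model(
    features_dataset, labels_dataset
)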

def fast_prep_dataset_for_model(self, train_features, train_labels):
"""Prepare the dataset for training with encryption by setting the
batch size to the same value as the encryption ring degree. It is faster
than `prep_dataset_for_model` because it does not execute the graph,
instead tracing and optimizing the graph and extracting the required
parameters without actually executing the graph.
Since the graph is not executed, caches for keys and the shell context
are not written to disk.
"""
if self.disable_encryption:
self.batch_size = next(iter(train_features)).shape[0]
self.dataset_prepped = True
return train_dataset
return train_features, train_labels

# Call the training step with keygen to trace the graph. Use a copy
# of the function to avoid caching the trace.
traceable_copy = self.train_step_tf_func
func = traceable_copy.get_concrete_function(next(iter(train_dataset)))
# Call the training step with keygen to trace the graph. Use a copy of
# the function to avoid caching the trace so keys and context are
# written to cache and read on next trace.
func = self.train_step_with_keygen.get_concrete_function(
next(iter(train_features)), next(iter(train_labels))
)

# Optimize the graph using tf_shell's HE-specific optimizers.
optimized_func = tf_shell.optimize_shell_graph(
Expand All @@ -148,12 +179,30 @@ def get_tensor_by_name(g, name):
raise ValueError(f"Node {name} not found in graph.")

log_n = get_tensor_by_name(optimized_graph, context_node.input[0]).tolist()
self.batch_size = 2**log_n

with tf.device(self.features_party_dev):
train_features = train_features.rebatch(2**log_n, drop_remainder=True)
with tf.device(self.labels_party_dev):
train_labels = train_labels.rebatch(2**log_n, drop_remainder=True)

train_dataset = train_dataset.unbatch().batch(2**log_n, drop_remainder=True)
self.dataset_prepped = True
return train_dataset
return train_features, train_labels

def fit(self, train_dataset, **kwargs):
def fit(
self,
features_dataset,
labels_dataset,
epochs=1,
batch_size=32,
callbacks=None,
validation_data=None,
steps_per_epoch=None,
verbose=1,
):
"""A custom training loop that supports inputs from multiple datasets,
each of which can be on a different device.
"""
# Prevent TensorFlow from placing ops on devices which were not
# explicitly assigned for security reasons.
tf.config.set_soft_device_placement(False)
@@ -162,9 +211,93 @@ def fit(self, train_dataset, **kwargs):
tf_shell.enable_optimization()

if not self.dataset_prepped:
train_dataset = self.prep_dataset_for_model(train_dataset)
features_dataset, labels_dataset = self.prep_dataset_for_model(
features_dataset, labels_dataset
)

return super().fit(train_dataset, **kwargs)
# Calculate samples if possible.
if steps_per_epoch is None:
samples = None
else:
samples = steps_per_epoch * self.batch_size

# Initialize callbacks.
callback_list = CallbackList(
callbacks,
add_history=True,
add_progbar=verbose != 0,
model=self,
batch_size=self.batch_size,
epochs=epochs,
steps=steps_per_epoch,
samples=samples,
verbose=verbose,
do_validation=validation_data is not None,
metrics=list(self.metrics_names),
)

# Begin training.
callback_list.on_train_begin()
logs = {}

for epoch in range(epochs):
callback_list.on_epoch_begin(epoch, logs)
start_time = time.time()
self.reset_metrics()

# Training loop.
for step, (batch_x, batch_y) in enumerate(
zip(features_dataset, labels_dataset)
):
callback_list.on_train_batch_begin(step, logs)
logs, num_slots = self.train_step_tf_func(batch_x, batch_y)
callback_list.on_train_batch_end(step, logs)
if steps_per_epoch is not None and step + 1 >= steps_per_epoch:
break

# Validation loop.
if validation_data is not None:
# Reset metrics
self.reset_metrics()

for val_x_batch, val_y_batch in validation_data:
val_y_pred = self(val_x_batch, training=False)
# Update validation metrics
for m in self.metrics:
if m.name == "loss":
loss = self.loss_fn(val_y_batch, val_y_pred)
m.update_state(loss)
else:
m.update_state(val_y_batch, val_y_pred)
metric_results = {m.name: m.result() for m in self.metrics}

# TensorFlow 2.18.0 added a "CompiledMetrics" metric which holds
# metrics passed to compile in its own dictionary. Keras wants
# all metrics to be returned as a flat dictionary. Here we
# flatten the dictionary.
result = {}
for key, value in metric_results.items():
if isinstance(value, dict):
result.update(value) # add subdict directly into the dict
else:
result[key] = value # non-subdict elements are just copied

logs.update({f"val_{name}": result for name, result in result.items()})

# End of epoch.
logs["time"] = time.time() - start_time

# Update the steps in callback parameters with actual steps completed
if steps_per_epoch is None:
steps_per_epoch = step + 1
samples = steps_per_epoch * self.batch_size
callback_list.params["steps"] = steps_per_epoch
callback_list.params["samples"] = samples
callback_list.on_epoch_end(epoch, logs)

# End of training.
callback_list.on_train_end(logs)
return self.history
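Continuing the sketch above, the overridden fit consumes the two datasets directly. The validation loop above expects validation_data to yield plaintext (features, labels) batches; here the training data is reused purely to illustrate that format.

history = model.fit(
    features_dataset,
    labels_dataset,
    epochs=2,
    validation_data=tf.data.Dataset.zip((features_dataset, labels_dataset)),
)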

def flatten_jacobian_list(self, grads):
"""Takes as input a jacobian and flattens into a single tensor. The
@@ -180,7 +313,7 @@ def flatten_jacobian_list(self, grads):
# Get the shapes from TensorFlow's tensors, not SHELL's context for when
# the batch size != slotting dim or not using encryption.
slot_size = tf.shape(grads[0])[0]
num_output_classes = grads[0].shape[1]
num_output_classes = tf.shape(grads[0])[1]
grad_shapes = [g.shape[2:] for g in grads]
flattened_grad_shapes = [s.num_elements() for s in grad_shapes]

@@ -292,9 +425,10 @@ def unflatten_and_unpad_grad(
return grads_list

def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
# If the gradient is between [-t/2, -t/4] or [t/4, t/2], the gradient
# may have overflowed. This also must take the scaling factor into
# account so the range is divided by the scaling factor.
"""If the gradient is between [-t/2, -t/4] or [t/4, t/2], the gradient
may have overflowed. This also must take the scaling factor into account
so the range is divided by the scaling factor.
"""
t = tf.cast(plaintext_modulus, grads[0].dtype)
t_half = t / 2


0 comments on commit 2237f5e
