Commit: Tensorflow 2.18 compatibility.

james-choncholas committed Oct 29, 2024
1 parent d4e69aa commit 6d070a9
Showing 9 changed files with 153 additions and 165 deletions.
8 changes: 6 additions & 2 deletions tf_shell/cc/ops/shell_ops.cc
@@ -241,6 +241,10 @@ REGISTER_OP("ReduceSumWithModulusPt64")
     .Output("reduced: dtype")
     .SetShapeFn([](InferenceContext* c) {
       tsl::int32 rank = c->Rank(c->input(1));
+      if (rank == -1) {
+        c->set_output(0, c->UnknownShape());
+        return OkStatus();
+      }
 
       tsl::int32 axis;
       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
@@ -250,8 +254,8 @@ REGISTER_OP("ReduceSumWithModulusPt64")
         clamped_axis += rank;
       }
       if (clamped_axis < 0 || clamped_axis > rank) {
-        return InvalidArgument("axis must be in the range [0, rank], got ",
-                               clamped_axis);
+        return InvalidArgument("axis must be in the range [0, rank], got axis ",
+                               clamped_axis, " and rank ", rank);
       }
 
       ShapeHandle output;
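
A note on the shape-inference guard added above: TensorFlow reports unknown rank as -1 during shape inference, which happens whenever an op is traced with an input whose rank is not pinned down. A minimal sketch of how such an input arises, using a stock reduction op purely for illustration (not the shell op itself):

import tensorflow as tf

# A TensorSpec with shape=None leaves the rank unknown at trace time, so any
# shape function consuming this input sees rank == -1 and must propagate an
# unknown output shape rather than raising an error.
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def reduce_any_rank(x):
    return tf.reduce_sum(x, axis=0)

print(reduce_any_rank(tf.ones([4, 2])))  # traced once, works for any rank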
93 changes: 26 additions & 67 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -24,18 +24,6 @@ class DpSgdSequential(SequentialBase):
     def __init__(
         self,
         layers,
-        backprop_context_fn,
-        noise_context_fn,
-        labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        noise_multiplier=1.0,
-        disable_encryption=False,
-        disable_masking=False,
-        disable_noise=False,
-        check_overflow_INSECURE=False,
-        cache_path=None,
-        jacobian_pfor=False,
-        jacobian_pfor_iterations=None,
         *args,
         **kwargs,
     ):
@@ -52,31 +40,6 @@ def __init__(
         # be a no-op.
         self.layers[-1].activation_deriv = None
 
-        self.backprop_context_fn = backprop_context_fn
-        self.noise_context_fn = noise_context_fn
-        self.labels_party_dev = labels_party_dev
-        self.features_party_dev = features_party_dev
-        self.noise_multiplier = noise_multiplier
-        self.disable_encryption = disable_encryption
-        self.disable_masking = disable_masking
-        self.disable_noise = disable_noise
-        self.check_overflow_INSECURE = check_overflow_INSECURE
-        self.cache_path = cache_path
-        self.jacobian_pfor = jacobian_pfor
-        self.jacobian_pfor_iterations = jacobian_pfor_iterations
-
-    def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
-        super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
-
-        if shell_loss is None:
-            raise ValueError("shell_loss must be provided")
-        self.loss_fn = shell_loss
-
-        # Keras ignores metrics that are not used during training. When training
-        # with encryption, the metrics are not updated. Store the metrics so
-        # they can be recovered during validation.
-        self.val_metrics = metrics
-
     def call(self, x, training=False):
         for l in self.layers:
             x = l(x, training=training)
@@ -103,7 +66,7 @@ def compute_max_two_norm_and_pred(self, features, skip_two_norm):
 
         return y_pred, max_two_norm
 
-    def train_step(self, data):
+    def shell_train_step(self, data):
         x, y = data
 
         with tf.device(self.labels_party_dev):
@@ -281,36 +244,32 @@ def rebalance(x, s):
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
 
-        # Do not update metrics during secure training.
-        if self.disable_encryption:
-            # Update metrics (includes the metric that tracks the loss)
-            for metric in self.metrics:
-                if metric.name == "loss":
+        for metric in self.metrics:
+            if metric.name == "loss":
+                if self.disable_encryption:
+                    metric.update_state(0)  # Loss is not known when encrypted.
+                else:
                     loss = self.loss_fn(y, y_pred)
                     metric.update_state(loss)
-                else:
+            else:
+                if self.disable_encryption:
                     metric.update_state(y, y_pred)
+                else:
+                    zeros = tf.broadcast_to(0, tf.shape(y_pred))
+                    metric.update_state(
+                        zeros, zeros
+                    )  # No other metrics when encrypted.
 
+        metric_results = {m.name: m.result() for m in self.metrics}
+
+        # TensorFlow 2.18.0 added a "CompiledMetrics" metric which holds metrics
+        # passed to compile in its own dictionary. Keras wants all metrics to
+        # be returned as a flat dictionary. Here we flatten the dictionary.
+        result = {}
+        for key, value in metric_results.items():
+            if isinstance(value, dict):
+                result.update(value)  # add subdict directly into the dict
+            else:
+                result[key] = value  # non-subdict elements are just copied
 
-            metric_results = {m.name: m.result() for m in self.metrics}
-        else:
-            metric_results = {"num_slots": backprop_context.num_slots}
-
-        return metric_results
-
-    def test_step(self, data):
-        x, y = data
-
-        # Forward pass.
-        y_pred = self(x, training=False)
-
-        # Updates the metrics tracking the loss.
-        self.compute_loss(y=y, y_pred=y_pred)
-
-        # Update the other metrics.
-        for metric in self.val_metrics:
-            if metric.name != "loss" and metric.name != "num_slots":
-                metric.update_state(y, y_pred)
-
-        # Return a dict mapping metric names to current value.
-        # Note that it will include the loss (tracked in self.metrics).
-        return {m.name: m.result() for m in self.val_metrics}
+        return result, None if self.disable_encryption else backprop_context.num_slots
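
To illustrate the flattening step above in isolation: under TensorFlow 2.18, result() for the compiled-metrics container returns a nested dictionary, while Keras expects train_step to return a flat one. The dictionary contents below are hypothetical, chosen only to mirror that shape:

# Illustrative input: one scalar entry plus one nested dict, as produced by
# TF 2.18's compiled-metrics container (the key names here are hypothetical).
metric_results = {
    "loss": 0.125,
    "compile_metrics": {"categorical_accuracy": 0.92},
}

result = {}
for key, value in metric_results.items():
    if isinstance(value, dict):
        result.update(value)  # hoist nested metrics to the top level
    else:
        result[key] = value

assert result == {"loss": 0.125, "categorical_accuracy": 0.92}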
79 changes: 61 additions & 18 deletions tf_shell_ml/model_base.py
@@ -16,22 +16,75 @@
 import tensorflow as tf
 import tensorflow.keras as keras
 import tf_shell
+import tf_shell_ml
 
 
 class SequentialBase(keras.Sequential):
 
     def __init__(
         self,
         layers,
+        backprop_context_fn,
+        noise_context_fn,
+        labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
+        features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
+        noise_multiplier=1.0,
+        cache_path=None,
+        jacobian_pfor=False,
+        jacobian_pfor_iterations=None,
+        disable_encryption=False,
+        disable_masking=False,
+        disable_noise=False,
+        check_overflow_INSECURE=False,
         *args,
         **kwargs,
     ):
         super().__init__(layers, *args, **kwargs)
+        self.backprop_context_fn = backprop_context_fn
+        self.noise_context_fn = noise_context_fn
+        self.labels_party_dev = labels_party_dev
+        self.features_party_dev = features_party_dev
+        self.noise_multiplier = noise_multiplier
+        self.cache_path = cache_path
+        self.jacobian_pfor = jacobian_pfor
+        self.jacobian_pfor_iterations = jacobian_pfor_iterations
+        self.disable_encryption = disable_encryption
+        self.disable_masking = disable_masking
+        self.disable_noise = disable_noise
+        self.check_overflow_INSECURE = check_overflow_INSECURE
         self.dataset_prepped = False
 
+    def compile(self, shell_loss, **kwargs):
+        if not isinstance(shell_loss, tf_shell_ml.CategoricalCrossentropy):
+            raise ValueError(
+                "The model must be used with the tf-shell version of CategoricalCrossentropy loss function. Saw",
+                shell_loss,
+            )
+        if len(self.layers) > 0 and not (
+            self.layers[-1].activation is tf.keras.activations.softmax
+            or self.layers[-1].activation is tf.nn.softmax
+        ):
+            raise ValueError(
+                "The model must have a softmax activation function on the final layer. Saw",
+                self.layers[-1].activation,
+            )
+
+        if shell_loss is None:
+            raise ValueError("shell_loss must be provided")
+        self.loss_fn = shell_loss
+
+        super().compile(loss=tf.keras.losses.CategoricalCrossentropy(), **kwargs)
+
+    def train_step(self, data):
+        metrics, num_slots = self.shell_train_step(data)
+        return metrics
+
     @tf.function
     def train_step_tf_func(self, data):
-        return self.train_step(data)
+        return self.shell_train_step(data)
+
+    def shell_train_step(self, data):
+        raise NotImplementedError()  # Should be overloaded by the subclass.
 
     # Prepare the dataset for training with encryption by setting the batch size
     # to the same value as the encryption ring degree. Run the training loop once
@@ -43,31 +96,21 @@ def prep_dataset_for_model(self, train_dataset):
 
         # Run the training loop once on dummy data to figure out the batch size.
         tf.config.run_functions_eagerly(False)
-        metrics = self.train_step_tf_func(next(iter(train_dataset)))
+        metrics, num_slots = self.train_step_tf_func(next(iter(train_dataset)))
 
-        if not isinstance(metrics, dict):
-            raise ValueError(
-                f"Expected train_step to return a dict, got {type(metrics)}."
-            )
-
-        if "num_slots" not in metrics:
-            raise ValueError(
-                f"Expected train_step to return a dict with key 'num_slots', got {metrics.keys()}."
-            )
-
-        train_dataset = train_dataset.rebatch(
-            metrics["num_slots"].numpy(), drop_remainder=True
-        )
+        train_dataset = train_dataset.rebatch(num_slots.numpy(), drop_remainder=True)
 
         self.dataset_prepped = True
         return train_dataset
 
     # Prepare the dataset for training with encryption by setting the batch size
     # to the same value as the encryption ring degree. It is faster than
     # `prep_dataset_for_model` because it does not execute the graph, instead
-    # tracing and optimizing the graph and extracting the required parameters.
+    # tracing and optimizing the graph and extracting the required parameters
+    # without actually executing the graph.
     def fast_prep_dataset_for_model(self, train_dataset):
-        if not self.disable_encryption:
+        if self.disable_encryption:
             self.dataset_prepped = True
             return train_dataset
 
         # Call the training step with keygen to trace the graph. Use a copy
@@ -101,7 +144,7 @@ def get_tensor_by_name(g, name):
 
         log_n = get_tensor_by_name(optimized_graph, context_node.input[0]).tolist()
 
-        train_dataset = train_dataset.rebatch(2**log_n, drop_remainder=True)
+        train_dataset = train_dataset.unbatch().batch(2**log_n, drop_remainder=True)
         self.dataset_prepped = True
         return train_dataset
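
With the constructor arguments and compile() now owned by the base class, DpSgdSequential and PostScaleSequential share one calling convention. A hedged usage sketch under stated assumptions: make_backprop_context, make_noise_context, and train_dataset are hypothetical stand-ins, and the shell-compatible layer stack is elided; only DpSgdSequential, the shell_loss argument, and prep_dataset_for_model come from this commit.

import tensorflow as tf
import tf_shell_ml

# Assumptions: `make_backprop_context` / `make_noise_context` stand in for the
# application's tf_shell context constructors, and `train_dataset` is a batched
# tf.data.Dataset of (features, one-hot labels).
model = tf_shell_ml.DpSgdSequential(
    layers=[...],  # shell-compatible layers; compile() enforces a final softmax
    backprop_context_fn=make_backprop_context,
    noise_context_fn=make_noise_context,
    noise_multiplier=1.0,
)

model.compile(
    shell_loss=tf_shell_ml.CategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.01),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

# Rebatch to the encryption ring degree, then train as usual.
train_dataset = model.prep_dataset_for_model(train_dataset)
model.fit(train_dataset, epochs=1)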

86 changes: 27 additions & 59 deletions tf_shell_ml/postscale_sequential_model.py
@@ -23,18 +23,6 @@ class PostScaleSequential(SequentialBase):
     def __init__(
         self,
         layers,
-        backprop_context_fn,
-        noise_context_fn,
-        labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        noise_multiplier=1.0,
-        disable_encryption=False,
-        disable_masking=False,
-        disable_noise=False,
-        check_overflow_INSECURE=False,
-        cache_path=None,
-        jacobian_pfor=False,
-        jacobian_pfor_iterations=None,
         *args,
         **kwargs,
     ):
@@ -51,38 +39,7 @@ def __init__(
         # be a no-op.
         self.layers[-1].activation_deriv = None
 
-        self.backprop_context_fn = backprop_context_fn
-        self.noise_context_fn = noise_context_fn
-        self.labels_party_dev = labels_party_dev
-        self.features_party_dev = features_party_dev
-        self.clipping_threshold = 10000000
-        self.context_prepped = False
-        self.disable_encryption = disable_encryption
-        self.noise_multiplier = noise_multiplier
-        self.disable_masking = disable_masking
-        self.disable_noise = disable_noise
-        self.check_overflow_INSECURE = check_overflow_INSECURE
-        self.cache_path = cache_path
-        self.jacobian_pfor = jacobian_pfor
-        self.jacobian_pfor_iterations = jacobian_pfor_iterations
-
-    def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
-        super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
-
-        if shell_loss is None:
-            raise ValueError("shell_loss must be provided")
-        self.loss_fn = shell_loss
-
-        # Keras ignores metrics that are not used during training. When training
-        # with encryption, the metrics are not updated. Store the metrics so
-        # they can be recovered during validation.
-        self.val_metrics = metrics
-
-    def _unpack(self, x_list):
-        batch_size = tf.shape(x_list[0])[0] // 2
-        return [x[0] + x[batch_size] for x in x_list]
-
-    def train_step(self, data):
+    def shell_train_step(self, data):
         x, y = data
 
         with tf.device(self.labels_party_dev):
@@ -157,10 +114,7 @@ def train_step(self, data):
                 # Note, checking the backprop gradients requires decryption
                 # on the features party which breaks security of the protocol.
                 bp_scaling_factors = [g._scaling_factor for g in grads]
-                dec_grads = [
-                    tf_shell.to_tensorflow(g, backprop_secret_fastrot_key)
-                    for g in grads
-                ]
+                dec_grads = [tf_shell.to_tensorflow(g, secret_key) for g in grads]
                 self.warn_on_overflow(
                     dec_grads,
                     bp_scaling_factors,
@@ -292,18 +246,32 @@ def rebalance(x, s):
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
 
-        # Do not update metrics during secure training.
-        if self.disable_encryption:
-            # Update metrics (includes the metric that tracks the loss)
-            for metric in self.metrics:
-                if metric.name == "loss":
+        for metric in self.metrics:
+            if metric.name == "loss":
+                if self.disable_encryption:
+                    metric.update_state(0)  # Loss is not known when encrypted.
+                else:
                     loss = self.loss_fn(y, y_pred)
                     metric.update_state(loss)
-                else:
+            else:
+                if self.disable_encryption:
                     metric.update_state(y, y_pred)
+                else:
+                    zeros = tf.broadcast_to(0, tf.shape(y_pred))
+                    metric.update_state(
+                        zeros, zeros
+                    )  # No other metrics when encrypted.
 
+        metric_results = {m.name: m.result() for m in self.metrics}
+
+        # TensorFlow 2.18.0 added a "CompiledMetrics" metric which holds metrics
+        # passed to compile in its own dictionary. Keras wants all metrics to
+        # be returned as a flat dictionary. Here we flatten the dictionary.
+        result = {}
+        for key, value in metric_results.items():
+            if isinstance(value, dict):
+                result.update(value)  # add subdict directly into the dict
+            else:
+                result[key] = value  # non-subdict elements are just copied
 
-            metric_results = {m.name: m.result() for m in self.metrics}
-        else:
-            metric_results = {"num_slots": backprop_context.num_slots}
-
-        return metric_results
+        return result, None if self.disable_encryption else backprop_context.num_slots
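
One way to read the net effect of this commit: num_slots moved out of the metrics dictionary and into a second return value. A condensed sketch of that contract, using only pieces defined in model_base.py above (model and train_dataset assumed to exist):

# shell_train_step returns (metrics_dict, num_slots); num_slots is None when
# encryption is disabled. Keras's train_step keeps only the metrics, while
# dataset prep uses num_slots to rebatch to the ring degree.
metrics, num_slots = model.train_step_tf_func(next(iter(train_dataset)))
if num_slots is not None:
    train_dataset = train_dataset.rebatch(num_slots.numpy(), drop_remainder=True)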
