diff --git a/tf_shell/cc/ops/shell_ops.cc b/tf_shell/cc/ops/shell_ops.cc
index 28d71a5..2bb21d1 100644
--- a/tf_shell/cc/ops/shell_ops.cc
+++ b/tf_shell/cc/ops/shell_ops.cc
@@ -241,6 +241,10 @@ REGISTER_OP("ReduceSumWithModulusPt64")
     .Output("reduced: dtype")
     .SetShapeFn([](InferenceContext* c) {
       tsl::int32 rank = c->Rank(c->input(1));
+      if (rank == -1) {
+        c->set_output(0, c->UnknownShape());
+        return OkStatus();
+      }
       tsl::int32 axis;
       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
 
@@ -250,8 +254,8 @@
         clamped_axis += rank;
       }
       if (clamped_axis < 0 || clamped_axis > rank) {
-        return InvalidArgument("axis must be in the range [0, rank], got ",
-                               clamped_axis);
+        return InvalidArgument("axis must be in the range [0, rank], got axis ",
+                               clamped_axis, " and rank ", rank);
       }
 
       ShapeHandle output;
diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py
index a38b9c4..dbe4949 100644
--- a/tf_shell_ml/dpsgd_sequential_model.py
+++ b/tf_shell_ml/dpsgd_sequential_model.py
@@ -24,18 +24,6 @@ class DpSgdSequential(SequentialBase):
     def __init__(
         self,
         layers,
-        backprop_context_fn,
-        noise_context_fn,
-        labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        noise_multiplier=1.0,
-        disable_encryption=False,
-        disable_masking=False,
-        disable_noise=False,
-        check_overflow_INSECURE=False,
-        cache_path=None,
-        jacobian_pfor=False,
-        jacobian_pfor_iterations=None,
         *args,
         **kwargs,
     ):
@@ -52,31 +40,6 @@ def __init__(
         # be a no-op.
         self.layers[-1].activation_deriv = None
 
-        self.backprop_context_fn = backprop_context_fn
-        self.noise_context_fn = noise_context_fn
-        self.labels_party_dev = labels_party_dev
-        self.features_party_dev = features_party_dev
-        self.noise_multiplier = noise_multiplier
-        self.disable_encryption = disable_encryption
-        self.disable_masking = disable_masking
-        self.disable_noise = disable_noise
-        self.check_overflow_INSECURE = check_overflow_INSECURE
-        self.cache_path = cache_path
-        self.jacobian_pfor = jacobian_pfor
-        self.jacobian_pfor_iterations = jacobian_pfor_iterations
-
-    def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
-        super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
-
-        if shell_loss is None:
-            raise ValueError("shell_loss must be provided")
-        self.loss_fn = shell_loss
-
-        # Keras ignores metrics that are not used during training. When training
-        # with encryption, the metrics are not updated. Store the metrics so
-        # they can be recovered during validation.
-        self.val_metrics = metrics
-
     def call(self, x, training=False):
         for l in self.layers:
             x = l(x, training=training)
@@ -103,7 +66,7 @@ def compute_max_two_norm_and_pred(self, features, skip_two_norm):
 
         return y_pred, max_two_norm
 
-    def train_step(self, data):
+    def shell_train_step(self, data):
         x, y = data
 
         with tf.device(self.labels_party_dev):
@@ -281,36 +244,32 @@ def rebalance(x, s):
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
 
-        # Do not update metrics during secure training.
-        if self.disable_encryption:
-            # Update metrics (includes the metric that tracks the loss)
-            for metric in self.metrics:
-                if metric.name == "loss":
+        for metric in self.metrics:
+            if metric.name == "loss":
+                if self.disable_encryption:
+                    metric.update_state(0)  # Loss is not known when encrypted.
+                else:
                     loss = self.loss_fn(y, y_pred)
                     metric.update_state(loss)
-                else:
+            else:
+                if self.disable_encryption:
                     metric.update_state(y, y_pred)
+                else:
+                    zeros = tf.broadcast_to(0, tf.shape(y_pred))
+                    metric.update_state(
+                        zeros, zeros
+                    )  # No other metrics when encrypted.
+
+        metric_results = {m.name: m.result() for m in self.metrics}
+
+        # TensorFlow 2.18.0 added a "CompiledMetrics" metric which holds metrics
+        # passed to compile in its own dictionary. Keras wants all metrics to
+        # be returned as a flat dictionary. Here we flatten the dictionary.
+        result = {}
+        for key, value in metric_results.items():
+            if isinstance(value, dict):
+                result.update(value)  # add subdict directly into the dict
+            else:
+                result[key] = value  # non-subdict elements are just copied
-            metric_results = {m.name: m.result() for m in self.metrics}
-        else:
-            metric_results = {"num_slots": backprop_context.num_slots}
-
-        return metric_results
-
-    def test_step(self, data):
-        x, y = data
-
-        # Forward pass.
-        y_pred = self(x, training=False)
-
-        # Updates the metrics tracking the loss.
-        self.compute_loss(y=y, y_pred=y_pred)
-
-        # Update the other metrics.
-        for metric in self.val_metrics:
-            if metric.name != "loss" and metric.name != "num_slots":
-                metric.update_state(y, y_pred)
-
-        # Return a dict mapping metric names to current value.
-        # Note that it will include the loss (tracked in self.metrics).
-        return {m.name: m.result() for m in self.val_metrics}
+        return result, None if self.disable_encryption else backprop_context.num_slots
 
diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py
index ea26e1d..17c113c 100644
--- a/tf_shell_ml/model_base.py
+++ b/tf_shell_ml/model_base.py
@@ -16,6 +16,7 @@
 import tensorflow as tf
 import tensorflow.keras as keras
 import tf_shell
+import tf_shell_ml
 
 
 class SequentialBase(keras.Sequential):
@@ -23,15 +24,67 @@
     def __init__(
         self,
         layers,
+        backprop_context_fn,
+        noise_context_fn,
+        labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
+        features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
+        noise_multiplier=1.0,
+        cache_path=None,
+        jacobian_pfor=False,
+        jacobian_pfor_iterations=None,
+        disable_encryption=False,
+        disable_masking=False,
+        disable_noise=False,
+        check_overflow_INSECURE=False,
         *args,
         **kwargs,
     ):
         super().__init__(layers, *args, **kwargs)
+        self.backprop_context_fn = backprop_context_fn
+        self.noise_context_fn = noise_context_fn
+        self.labels_party_dev = labels_party_dev
+        self.features_party_dev = features_party_dev
+        self.noise_multiplier = noise_multiplier
+        self.cache_path = cache_path
+        self.jacobian_pfor = jacobian_pfor
+        self.jacobian_pfor_iterations = jacobian_pfor_iterations
+        self.disable_encryption = disable_encryption
+        self.disable_masking = disable_masking
+        self.disable_noise = disable_noise
+        self.check_overflow_INSECURE = check_overflow_INSECURE
         self.dataset_prepped = False
 
+    def compile(self, shell_loss, **kwargs):
+        if not isinstance(shell_loss, tf_shell_ml.CategoricalCrossentropy):
+            raise ValueError(
+                "The model must be used with the tf-shell version of CategoricalCrossentropy loss function. Saw",
+                shell_loss,
+            )
+        if len(self.layers) > 0 and not (
+            self.layers[-1].activation is tf.keras.activations.softmax
+            or self.layers[-1].activation is tf.nn.softmax
+        ):
+            raise ValueError(
+                "The model must have a softmax activation function on the final layer. Saw",
+                self.layers[-1].activation,
+            )
+
+        if shell_loss is None:
+            raise ValueError("shell_loss must be provided")
+        self.loss_fn = shell_loss
+
+        super().compile(loss=tf.keras.losses.CategoricalCrossentropy(), **kwargs)
+
+    def train_step(self, data):
+        metrics, num_slots = self.shell_train_step(data)
+        return metrics
+
     @tf.function
     def train_step_tf_func(self, data):
-        return self.train_step(data)
+        return self.shell_train_step(data)
+
+    def shell_train_step(self, data):
+        raise NotImplementedError()  # Should be overloaded by the subclass.
 
     # Prepare the dataset for training with encryption by setting the batch size
     # to the same value as the encryption ring degree. Run the training loop once
@@ -43,21 +96,9 @@ def prep_dataset_for_model(self, train_dataset):
 
         # Run the training loop once on dummy data to figure out the batch size.
         tf.config.run_functions_eagerly(False)
-        metrics = self.train_step_tf_func(next(iter(train_dataset)))
+        metrics, num_slots = self.train_step_tf_func(next(iter(train_dataset)))
 
-        if not isinstance(metrics, dict):
-            raise ValueError(
-                f"Expected train_step to return a dict, got {type(metrics)}."
-            )
-
-        if "num_slots" not in metrics:
-            raise ValueError(
-                f"Expected train_step to return a dict with key 'num_slots', got {metrics.keys()}."
-            )
-
-        train_dataset = train_dataset.rebatch(
-            metrics["num_slots"].numpy(), drop_remainder=True
-        )
+        train_dataset = train_dataset.rebatch(num_slots.numpy(), drop_remainder=True)
 
         self.dataset_prepped = True
         return train_dataset
@@ -65,9 +106,11 @@ def prep_dataset_for_model(self, train_dataset):
     # Prepare the dataset for training with encryption by setting the batch size
     # to the same value as the encryption ring degree. It is faster than
    # `prep_dataset_for_model` because it does not execute the graph, instead
-    # tracing and optimizing the graph and extracting the required parameters.
+    # tracing and optimizing the graph and extracting the required parameters
+    # without actually executing the graph.
     def fast_prep_dataset_for_model(self, train_dataset):
-        if not self.disable_encryption:
+        if self.disable_encryption:
+            self.dataset_prepped = True
             return train_dataset
 
         # Call the training step with keygen to trace the graph. Use a copy
@@ -101,7 +144,7 @@ def get_tensor_by_name(g, name):
 
         log_n = get_tensor_by_name(optimized_graph, context_node.input[0]).tolist()
 
-        train_dataset = train_dataset.rebatch(2**log_n, drop_remainder=True)
+        train_dataset = train_dataset.unbatch().batch(2**log_n, drop_remainder=True)
 
         self.dataset_prepped = True
         return train_dataset
diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py
index 7d503e6..a861461 100644
--- a/tf_shell_ml/postscale_sequential_model.py
+++ b/tf_shell_ml/postscale_sequential_model.py
@@ -23,18 +23,6 @@ class PostScaleSequential(SequentialBase):
     def __init__(
         self,
         layers,
-        backprop_context_fn,
-        noise_context_fn,
-        labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        noise_multiplier=1.0,
-        disable_encryption=False,
-        disable_masking=False,
-        disable_noise=False,
-        check_overflow_INSECURE=False,
-        cache_path=None,
-        jacobian_pfor=False,
-        jacobian_pfor_iterations=None,
         *args,
         **kwargs,
     ):
@@ -51,38 +39,7 @@ def __init__(
         # be a no-op.
         self.layers[-1].activation_deriv = None
 
-        self.backprop_context_fn = backprop_context_fn
-        self.noise_context_fn = noise_context_fn
-        self.labels_party_dev = labels_party_dev
-        self.features_party_dev = features_party_dev
-        self.clipping_threshold = 10000000
-        self.context_prepped = False
-        self.disable_encryption = disable_encryption
-        self.noise_multiplier = noise_multiplier
-        self.disable_masking = disable_masking
-        self.disable_noise = disable_noise
-        self.check_overflow_INSECURE = check_overflow_INSECURE
-        self.cache_path = cache_path
-        self.jacobian_pfor = jacobian_pfor
-        self.jacobian_pfor_iterations = jacobian_pfor_iterations
-
-    def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
-        super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
-
-        if shell_loss is None:
-            raise ValueError("shell_loss must be provided")
-        self.loss_fn = shell_loss
-
-        # Keras ignores metrics that are not used during training. When training
-        # with encryption, the metrics are not updated. Store the metrics so
-        # they can be recovered during validation.
-        self.val_metrics = metrics
-
-    def _unpack(self, x_list):
-        batch_size = tf.shape(x_list[0])[0] // 2
-        return [x[0] + x[batch_size] for x in x_list]
-
-    def train_step(self, data):
+    def shell_train_step(self, data):
         x, y = data
 
         with tf.device(self.labels_party_dev):
@@ -157,10 +114,7 @@ def train_step(self, data):
                 # Note, checking the backprop gradients requires decryption
                 # on the features party which breaks security of the protocol.
                 bp_scaling_factors = [g._scaling_factor for g in grads]
-                dec_grads = [
-                    tf_shell.to_tensorflow(g, backprop_secret_fastrot_key)
-                    for g in grads
-                ]
+                dec_grads = [tf_shell.to_tensorflow(g, secret_key) for g in grads]
                 self.warn_on_overflow(
                     dec_grads,
                     bp_scaling_factors,
@@ -292,18 +246,32 @@ def rebalance(x, s):
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
 
-        # Do not update metrics during secure training.
-        if self.disable_encryption:
-            # Update metrics (includes the metric that tracks the loss)
-            for metric in self.metrics:
-                if metric.name == "loss":
+        for metric in self.metrics:
+            if metric.name == "loss":
+                if self.disable_encryption:
+                    metric.update_state(0)  # Loss is not known when encrypted.
+                else:
                     loss = self.loss_fn(y, y_pred)
                     metric.update_state(loss)
-                else:
+            else:
+                if self.disable_encryption:
                     metric.update_state(y, y_pred)
+                else:
+                    zeros = tf.broadcast_to(0, tf.shape(y_pred))
+                    metric.update_state(
+                        zeros, zeros
+                    )  # No other metrics when encrypted.
+
+        metric_results = {m.name: m.result() for m in self.metrics}
+
+        # TensorFlow 2.18.0 added a "CompiledMetrics" metric which holds metrics
+        # passed to compile in its own dictionary. Keras wants all metrics to
+        # be returned as a flat dictionary. Here we flatten the dictionary.
+        result = {}
+        for key, value in metric_results.items():
+            if isinstance(value, dict):
+                result.update(value)  # add subdict directly into the dict
+            else:
+                result[key] = value  # non-subdict elements are just copied
-            metric_results = {m.name: m.result() for m in self.metrics}
-        else:
-            metric_results = {"num_slots": backprop_context.num_slots}
-
-        return metric_results
+        return result, None if self.disable_encryption else backprop_context.num_slots
 
diff --git a/tf_shell_ml/test/dpsgd_conv_model_local_test.py b/tf_shell_ml/test/dpsgd_conv_model_local_test.py
index 3c9ba9a..e148b25 100644
--- a/tf_shell_ml/test/dpsgd_conv_model_local_test.py
+++ b/tf_shell_ml/test/dpsgd_conv_model_local_test.py
@@ -78,7 +78,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
                 ),
             ],
             backprop_context_fn=lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=22,
+                log2_cleartext_sz=18,
                 scaling_factor=2,
                 noise_offset_log2=-20,
                 cache_path=cache,
@@ -93,7 +93,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
             disable_masking=disable_masking,
             disable_noise=disable_noise,
             cache_path=cache,
-            check_overflow_INSECURE=True,
+            # check_overflow_INSECURE=True,
             # jacobian_pfor=True,
             # jacobian_pfor_iterations=128,
         )
@@ -101,23 +101,28 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
         m.compile(
             shell_loss=tf_shell_ml.CategoricalCrossentropy(),
             optimizer=tf.keras.optimizers.Adam(0.1),
-            loss=tf.keras.losses.CategoricalCrossentropy(),
             metrics=[tf.keras.metrics.CategoricalAccuracy()],
         )
 
         m.build([None, 28, 28, 1])
         m.summary()
 
-        history = m.fit(train_dataset.take(8), epochs=1, validation_data=val_dataset)
+        history = m.fit(
+            train_dataset,
+            steps_per_epoch=8,
+            epochs=1,
+            verbose=2,
+            validation_data=val_dataset,
+        )
 
         self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.13)
 
     def test_model(self):
         with tempfile.TemporaryDirectory() as cache_dir:
             self._test_model(False, False, False, cache_dir)
-            # self._test_model(True, False, False, cache_dir)
-            # self._test_model(False, True, False, cache_dir)
-            # self._test_model(False, False, True, cache_dir)
+            self._test_model(True, False, False, cache_dir)
+            self._test_model(False, True, False, cache_dir)
+            self._test_model(False, False, True, cache_dir)
 
 
 if __name__ == "__main__":
diff --git a/tf_shell_ml/test/dpsgd_model_distrib_test.py b/tf_shell_ml/test/dpsgd_model_distrib_test.py
index 2125e01..34eeca7 100644
--- a/tf_shell_ml/test/dpsgd_model_distrib_test.py
+++ b/tf_shell_ml/test/dpsgd_model_distrib_test.py
@@ -110,20 +110,19 @@ def test_model(self):
             labels_party_dev=labels_party_dev,
             features_party_dev=features_party_dev,
             cache_path=cache,
-            # check_overflow_INSECURE=True,
-            # disable_noise=True,
         )
 
         m.compile(
             shell_loss=tf_shell_ml.CategoricalCrossentropy(),
             optimizer=tf.keras.optimizers.Adam(0.1),
-            loss=tf.keras.losses.CategoricalCrossentropy(),
             metrics=[tf.keras.metrics.CategoricalAccuracy()],
         )
 
         history = m.fit(
-            train_dataset.take(4),
+            train_dataset,
+            steps_per_epoch=8,
             epochs=1,
+            verbose=2,
             validation_data=val_dataset,
         )
 
diff --git a/tf_shell_ml/test/dpsgd_model_local_test.py b/tf_shell_ml/test/dpsgd_model_local_test.py
index 49575e8..b70cde0 100644
--- a/tf_shell_ml/test/dpsgd_model_local_test.py
+++ b/tf_shell_ml/test/dpsgd_model_local_test.py
@@ -74,14 +74,19 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
         m.compile(
             shell_loss=tf_shell_ml.CategoricalCrossentropy(),
             optimizer=tf.keras.optimizers.Adam(0.1),
-            loss=tf.keras.losses.CategoricalCrossentropy(),
             metrics=[tf.keras.metrics.CategoricalAccuracy()],
         )
 
         m.build([None, 784])
         m.summary()
 
-        history = m.fit(train_dataset.take(4), epochs=1, validation_data=val_dataset)
+        history = m.fit(
+            train_dataset,
+            steps_per_epoch=8,
+            epochs=1,
+            verbose=2,
+            validation_data=val_dataset,
+        )
 
         self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.30)
 
diff --git a/tf_shell_ml/test/dropout_test.py b/tf_shell_ml/test/dropout_test.py
index d5816ec..7a2da60 100644
--- a/tf_shell_ml/test/dropout_test.py
+++ b/tf_shell_ml/test/dropout_test.py
@@ -66,14 +66,14 @@ def _test_dropout_forward(self, per_batch):
         enc_x = tf_shell.to_encrypted(x, key, context)
 
         dropout_layer = tf_shell_ml.ShellDropout(0.2, per_batch=per_batch)
-        notrain_enc_y = dropout_layer(enc_x, training=False)
+        notrain_enc_y = dropout_layer.call(enc_x, training=False)
         self.assertAllClose(
             tf_shell.to_tensorflow(notrain_enc_y, key),
             x,
             atol=1 / context.scaling_factor,
         )
 
-        enc_train_y = dropout_layer(enc_x, training=True)
+        enc_train_y = dropout_layer.call(enc_x, training=True)
         dec_train_y = tf_shell.to_tensorflow(enc_train_y, key)
         self.assertLess(
             tf.math.count_nonzero(dec_train_y), tf.size(dec_train_y, out_type=tf.int64)
@@ -84,7 +84,7 @@ def _test_dropout_back(self, per_batch):
 
         dropout_layer = tf_shell_ml.ShellDropout(0.2, per_batch=per_batch)
 
-        notrain_y = dropout_layer(x, training=True)
+        notrain_y = dropout_layer.call(x, training=True)
 
         dy = tf.ones_like(notrain_y)
         dw, dx = dropout_layer.backward(dy, None)
diff --git a/tf_shell_ml/test/postscale_model_local_test.py b/tf_shell_ml/test/postscale_model_local_test.py
index c4bc249..6d51a85 100644
--- a/tf_shell_ml/test/postscale_model_local_test.py
+++ b/tf_shell_ml/test/postscale_model_local_test.py
@@ -46,7 +46,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
         m = tf_shell_ml.PostScaleSequential(
             [
                 tf.keras.layers.Dense(64, activation="relu"),
-                tf.keras.layers.Dense(10, activation="sigmoid"),
+                tf.keras.layers.Dense(10, activation="softmax"),
             ],
             lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=23,
@@ -69,11 +69,16 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
         m.compile(
             shell_loss=tf_shell_ml.CategoricalCrossentropy(),
             optimizer=tf.keras.optimizers.Adam(0.1),
-            loss=tf.keras.losses.CategoricalCrossentropy(),
             metrics=[tf.keras.metrics.CategoricalAccuracy()],
         )
 
-        history = m.fit(train_dataset.take(4), epochs=1, validation_data=val_dataset)
+        history = m.fit(
+            train_dataset,
+            steps_per_epoch=8,
+            epochs=1,
+            verbose=2,
+            validation_data=val_dataset,
+        )
 
         self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25)