diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py index 135962d..6cbb9e3 100644 --- a/tf_shell_ml/dpsgd_sequential_model.py +++ b/tf_shell_ml/dpsgd_sequential_model.py @@ -33,17 +33,18 @@ def __init__( self.layers[0].is_first_layer = True # Do not set the derivative of the activation function for the last # layer in the model. The derivative of the categorical crossentropy - # loss function times the derivative of a softmax is just y_pred - y - # (which is much easier to compute than each of them individually). - # So instead just let the loss function derivative incorporate - # y_pred - y and let the derivative of this last layer's activation - # be a no-op. + # loss function times the derivative of a softmax is just y_pred - + # labels (which is much easier to compute than each of them + # individually). So instead just let the loss function derivative + # incorporate y_pred - labels and let the derivative of this last + # layer's activation be a no-op. self.layers[-1].activation_deriv = None - def call(self, x, training=False): + def call(self, features, training=False): + out = features for l in self.layers: - x = l(x, training=training) - return x + out = l(out, training=training) + return out def compute_max_two_norm_and_pred(self, features, skip_two_norm): with tf.GradientTape( @@ -68,26 +69,28 @@ def compute_max_two_norm_and_pred(self, features, skip_two_norm): return y_pred, max_two_norm - def shell_train_step(self, data): - x, y = data - + def shell_train_step(self, features, labels): with tf.device(self.labels_party_dev): if self.disable_encryption: - enc_y = y + enc_y = labels else: backprop_context = self.backprop_context_fn() backprop_secret_key = tf_shell.create_key64( backprop_context, self.cache_path ) - # Encrypt the batch of secret labels y. - enc_y = tf_shell.to_encrypted(y, backprop_secret_key, backprop_context) + # Encrypt the batch of secret labels. + enc_y = tf_shell.to_encrypted( + labels, backprop_secret_key, backprop_context + ) - with tf.device(self.features_party_dev): + with tf.device(self.jacobian_device): + features = tf.identity(features) # copy to GPU if needed # Forward pass in plaintext. y_pred, max_two_norm = self.compute_max_two_norm_and_pred( - x, self.disable_noise + features, self.disable_noise ) + with tf.device(self.features_party_dev): # Backward pass. dx = self.loss_fn.grad(enc_y, y_pred) dJ_dw = [] # Derivatives of the loss with respect to the weights. @@ -199,7 +202,7 @@ def shell_train_step(self, data): # secret key. flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key) - if not self.disable_encryption and self.check_overflow_INSECURE: + if self.check_overflow_INSECURE: nosie_scaling_factors = grads._scaling_factor self.warn_on_overflow( [flat_grads], @@ -249,14 +252,14 @@ def rebalance(x, s): for metric in self.metrics: if metric.name == "loss": if self.disable_encryption: - loss = self.loss_fn(y, y_pred) + loss = self.loss_fn(labels, y_pred) metric.update_state(loss) else: # Loss is unknown when encrypted. metric.update_state(0.0) else: if self.disable_encryption: - metric.update_state(y, y_pred) + metric.update_state(labels, y_pred) else: # Other metrics are uknown when encrypted. zeros = tf.broadcast_to(0, tf.shape(y_pred)) diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py index 82beba8..f055752 100644 --- a/tf_shell_ml/model_base.py +++ b/tf_shell_ml/model_base.py @@ -15,8 +15,10 @@ # limitations under the License. 
import tensorflow as tf import tensorflow.keras as keras +from tensorflow.keras.callbacks import CallbackList import tf_shell import tf_shell_ml +import time class SequentialBase(keras.Sequential): @@ -32,6 +34,7 @@ def __init__( cache_path=None, jacobian_pfor=False, jacobian_pfor_iterations=None, + jacobian_device=None, disable_encryption=False, disable_masking=False, disable_noise=False, @@ -48,6 +51,9 @@ def __init__( self.cache_path = cache_path self.jacobian_pfor = jacobian_pfor self.jacobian_pfor_iterations = jacobian_pfor_iterations + self.jacobian_device = ( + features_party_dev if jacobian_device is None else jacobian_device + ) self.disable_encryption = disable_encryption self.disable_masking = disable_masking self.disable_noise = disable_noise @@ -59,12 +65,6 @@ def __init__( "WARNING: `jacobian_pfor` may be incompatible with `disable_encryption`." ) - def compile(self, shell_loss, **kwargs): - if not isinstance(shell_loss, tf_shell_ml.CategoricalCrossentropy): - raise ValueError( - "The model must be used with the tf-shell version of CategoricalCrossentropy loss function. Saw", - shell_loss, - ) if len(self.layers) > 0 and not ( self.layers[-1].activation is tf.keras.activations.softmax or self.layers[-1].activation is tf.nn.softmax @@ -74,54 +74,85 @@ def compile(self, shell_loss, **kwargs): self.layers[-1].activation, ) - if shell_loss is None: - raise ValueError("shell_loss must be provided") + def compile(self, shell_loss, **kwargs): + if not isinstance(shell_loss, tf_shell_ml.CategoricalCrossentropy): + raise ValueError( + "The model must be used with the tf_shell_ml version of CategoricalCrossentropy loss function. Saw", + shell_loss, + ) self.loss_fn = shell_loss - super().compile(loss=tf.keras.losses.CategoricalCrossentropy(), **kwargs) + super().compile( + loss=tf.keras.losses.CategoricalCrossentropy(), + jit_compile=False, # Disable XLA, no CPU op for tf_shell_ml's TensorArrayV2. + **kwargs, + ) - def train_step(self, data): - metrics, num_slots = self.shell_train_step(data) + def train_step(self, features, labels): + metrics, num_slots = self.shell_train_step(features, labels) return metrics @tf.function - def train_step_tf_func(self, data): - return self.shell_train_step(data) + def train_step_with_keygen(self, features, labels): + return self.shell_train_step(features, labels) - def shell_train_step(self, data): + @tf.function + def train_step_tf_func(self, features, labels): + return self.shell_train_step(features, labels) + + def shell_train_step(self, features, labels): raise NotImplementedError() # Should be overloaded by the subclass. - # Prepare the dataset for training with encryption by setting the batch size - # to the same value as the encryption ring degree. Run the training loop once - # on dummy data to figure out the batch size. - def prep_dataset_for_model(self, train_dataset): + def prep_dataset_for_model(self, train_features, train_labels): + """Prepare the dataset for training with encryption by setting the batch + size to the same value as the encryption ring degree. Run the training + loop once on dummy data to figure out the batch size. + """ if self.disable_encryption: + self.batch_size = next(iter(train_features)).shape[0] self.dataset_prepped = True - return train_dataset + return train_features, train_labels # Run the training loop once on dummy data to figure out the batch size. 
- tf.config.run_functions_eagerly(False) - metrics, num_slots = self.train_step_tf_func(next(iter(train_dataset))) + # Use a separate tf.function to avoid caching the trace so keys and + # context are written to cache and read on next trace. + metrics, num_slots = self.train_step_with_keygen( + next(iter(train_features)), next(iter(train_labels)) + ) + + self.batch_size = num_slots.numpy() - train_dataset = train_dataset.rebatch(num_slots.numpy(), drop_remainder=True) + with tf.device(self.features_party_dev): + train_features = train_features.rebatch( + num_slots.numpy(), drop_remainder=True + ) + with tf.device(self.labels_party_dev): + train_labels = train_labels.rebatch(num_slots.numpy(), drop_remainder=True) self.dataset_prepped = True - return train_dataset - - # Prepare the dataset for training with encryption by setting the batch size - # to the same value as the encryption ring degree. It is faster than - # `prep_dataset_for_model` because it does not execute the graph, instead - # tracing and optimizing the graph and extracting the required parameters - # without actually executing the graph. - def fast_prep_dataset_for_model(self, train_dataset): + return train_features, train_labels + + def fast_prep_dataset_for_model(self, train_features, train_labels): + """Prepare the dataset for training with encryption by setting the + batch size to the same value as the encryption ring degree. It is faster + than `prep_dataset_for_model` because it does not execute the graph, + instead tracing and optimizing the graph and extracting the required + parameters without actually executing the graph. + + Since the graph is not executed, caches for keys and the shell context + are not written to disk. + """ if self.disable_encryption: + self.batch_size = next(iter(train_features)).shape[0] self.dataset_prepped = True - return train_dataset + return train_features, train_labels - # Call the training step with keygen to trace the graph. Use a copy - # of the function to avoid caching the trace. - traceable_copy = self.train_step_tf_func - func = traceable_copy.get_concrete_function(next(iter(train_dataset))) + # Call the training step with keygen to trace the graph. Use a copy of + # the function to avoid caching the trace so keys and context are + # written to cache and read on next trace. + func = self.train_step_with_keygen.get_concrete_function( + next(iter(train_features)), next(iter(train_labels)) + ) # Optimize the graph using tf_shells HE-specific optimizers. optimized_func = tf_shell.optimize_shell_graph( @@ -148,12 +179,30 @@ def get_tensor_by_name(g, name): raise ValueError(f"Node {name} not found in graph.") log_n = get_tensor_by_name(optimized_graph, context_node.input[0]).tolist() + self.batch_size = 2**log_n + + with tf.device(self.features_party_dev): + train_features = train_features.rebatch(2**log_n, drop_remainder=True) + with tf.device(self.labels_party_dev): + train_labels = train_labels.rebatch(2**log_n, drop_remainder=True) - train_dataset = train_dataset.unbatch().batch(2**log_n, drop_remainder=True) self.dataset_prepped = True - return train_dataset + return train_features, train_labels - def fit(self, train_dataset, **kwargs): + def fit( + self, + features_dataset, + labels_dataset, + epochs=1, + batch_size=32, + callbacks=None, + validation_data=None, + steps_per_epoch=None, + verbose=1, + ): + """A custom training loop that supports inputs from multiple datasets, + each of which can be on a different device. 
+ """ # Prevent TensorFlow from placing ops on devices which were not # explicitly assigned for security reasons. tf.config.set_soft_device_placement(False) @@ -162,9 +211,93 @@ def fit(self, train_dataset, **kwargs): tf_shell.enable_optimization() if not self.dataset_prepped: - train_dataset = self.prep_dataset_for_model(train_dataset) + features_dataset, labels_dataset = self.prep_dataset_for_model( + features_dataset, labels_dataset + ) - return super().fit(train_dataset, **kwargs) + # Calculate samples if possible. + if steps_per_epoch is None: + samples = None + else: + samples = steps_per_epoch * self.batch_size + + # Initialize callbacks. + callback_list = CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + model=self, + batch_size=self.batch_size, + epochs=epochs, + steps=steps_per_epoch, + samples=samples, + verbose=verbose, + do_validation=validation_data is not None, + metrics=list(self.metrics_names), + ) + + # Begin training. + callback_list.on_train_begin() + logs = {} + + for epoch in range(epochs): + callback_list.on_epoch_begin(epoch, logs) + start_time = time.time() + self.reset_metrics() + + # Training loop. + for step, (batch_x, batch_y) in enumerate( + zip(features_dataset, labels_dataset) + ): + callback_list.on_train_batch_begin(step, logs) + logs, num_slots = self.train_step_tf_func(batch_x, batch_y) + callback_list.on_train_batch_end(step, logs) + if steps_per_epoch is not None and step + 1 >= steps_per_epoch: + break + + # Validation loop. + if validation_data is not None: + # Reset metrics + self.reset_metrics() + + for val_x_batch, val_y_batch in validation_data: + val_y_pred = self(val_x_batch, training=False) + # Update validation metrics + for m in self.metrics: + if m.name == "loss": + loss = self.loss_fn(val_y_batch, val_y_pred) + m.update_state(loss) + else: + m.update_state(val_y_batch, val_y_pred) + metric_results = {m.name: m.result() for m in self.metrics} + + # TensorFlow 2.18.0 added a "CompiledMetrics" metric which holds + # metrics passed to compile in it's own dictionary. Keras wants + # all metrics to be returned as a flat dictionary. Here we + # flatten the dictionary. + result = {} + for key, value in metric_results.items(): + if isinstance(value, dict): + result.update(value) # add subdict directly into the dict + else: + result[key] = value # non-subdict elements are just copied + + logs.update({f"val_{name}": result for name, result in result.items()}) + + # End of epoch. + logs["time"] = time.time() - start_time + + # Update the steps in callback parameters with actual steps completed + if steps_per_epoch is None: + steps_per_epoch = step + 1 + samples = steps_per_epoch * self.batch_size + callback_list.params["steps"] = steps_per_epoch + callback_list.params["samples"] = samples + callback_list.on_epoch_end(epoch, logs) + + # End of training. + callback_list.on_train_end(logs) + return self.history def flatten_jacobian_list(self, grads): """Takes as input a jacobian and flattens into a single tensor. The @@ -180,7 +313,7 @@ def flatten_jacobian_list(self, grads): # Get the shapes from TensorFlow's tensors, not SHELL's context for when # the batch size != slotting dim or not using encryption. 
slot_size = tf.shape(grads[0])[0] - num_output_classes = grads[0].shape[1] + num_output_classes = tf.shape(grads[0])[1] grad_shapes = [g.shape[2:] for g in grads] flattened_grad_shapes = [s.num_elements() for s in grad_shapes] @@ -292,9 +425,10 @@ def unflatten_and_unpad_grad( return grads_list def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message): - # If the gradient is between [-t/2, -t/4] or [t/4, t/2], the gradient - # may have overflowed. This also must take the scaling factor into - # account so the range is divided by the scaling factor. + """If the gradient is between [-t/2, -t/4] or [t/4, t/2], the gradient + may have overflowed. This also must take the scaling factor into account + so the range is divided by the scaling factor. + """ t = tf.cast(plaintext_modulus, grads[0].dtype) t_half = t / 2 diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py index c87a493..71fb545 100644 --- a/tf_shell_ml/postscale_sequential_model.py +++ b/tf_shell_ml/postscale_sequential_model.py @@ -30,37 +30,39 @@ def __init__( if len(self.layers) > 0: self.layers[0].is_first_layer = True - # Do not set the derivative of the activation function for the last - # layer in the model. The derivative of the categorical crossentropy - # loss function times the derivative of a softmax is just y_pred - y - # (which is much easier to compute than each of them individually). - # So instead just let the loss function derivative incorporate - # y_pred - y and let the derivative of this last layer's activation - # be a no-op. - self.layers[-1].activation_deriv = None - - def shell_train_step(self, data): - x, y = data + # Override the activation of the last layer. Set it to linear so + # when the jacobian is computed, the derivative of the activation + # function is a no-op. This is required for the post scale protocol. + self.layers[-1].activation = tf.keras.activations.linear + + def call(self, inputs, training=False, with_softmax=True): + prediction = super().call(inputs, training) + if not with_softmax: + return prediction + # Perform the last layer activation since it is removed for training + # purposes. + return tf.nn.softmax(prediction) + + def shell_train_step(self, features, labels): with tf.device(self.labels_party_dev): if self.disable_encryption: - enc_y = y + enc_y = labels else: backprop_context = self.backprop_context_fn() secret_key = tf_shell.create_key64(backprop_context, self.cache_path) - # Encrypt the batch of secret labels y. - enc_y = tf_shell.to_encrypted(y, secret_key, backprop_context) + # Encrypt the batch of secret labels. + enc_y = tf_shell.to_encrypted(labels, secret_key, backprop_context) - with tf.device(self.features_party_dev): - # Unset the activation function for the last layer so it is not used in - # computing the gradient. The effect of the last layer activation function - # is factored out of the gradient computation and accounted for below. 
- self.layers[-1].activation = tf.keras.activations.linear + with tf.device(self.jacobian_device): + features = tf.identity(features) # copy to GPU if needed + # self.layers[-1].activation = tf.keras.activations.linear with tf.GradientTape( persistent=tf.executing_eagerly() or self.jacobian_pfor ) as tape: - y_pred = self(x, training=True) # forward pass + y_pred = self(features, training=True, with_softmax=False) + grads = tape.jacobian( y_pred, self.trainable_variables, @@ -71,13 +73,12 @@ def shell_train_step(self, data): # ^ layers list x (batch size x num output classes x weights) matrix # dy_pred_j/dW_sample_class - # Reset the activation function for the last layer and compute the real - # prediction. - self.layers[-1].activation = tf.keras.activations.sigmoid - y_pred = self(x, training=False) + # compute the + # activation manually. + y_pred = tf.nn.softmax(y_pred) - # Compute y_pred - y (where y may be encrypted). - # scalars = y_pred - y # dJ/dy_pred + with tf.device(self.features_party_dev): + # Compute prediction - labels (where labels may be encrypted). scalars = enc_y.__rsub__(y_pred) # dJ/dy_pred # ^ batch_size x num output classes. @@ -162,13 +163,13 @@ def shell_train_step(self, data): grads = [tf_shell.to_tensorflow(g, secret_key) for g in grads] # Sum the masked gradients over the batch. - if self.disable_masking or self.disable_encryption: - grads = [tf.reduce_sum(g, 0) for g in grads] - else: + if not self.disable_masking and not self.disable_encryption: grads = [ tf_shell.reduce_sum_with_mod(g, 0, backprop_context, s) for g, s in zip(grads, mask_scaling_factors) ] + else: + grads = [tf.reduce_sum(g, 0) for g in grads] if not self.disable_noise: if not self.disable_encryption: @@ -202,7 +203,7 @@ def shell_train_step(self, data): # secret key. flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key) - if not self.disable_encryption and self.check_overflow_INSECURE: + if self.check_overflow_INSECURE: nosie_scaling_factor = grads._scaling_factor self.warn_on_overflow( [flat_grads], @@ -250,14 +251,14 @@ def rebalance(x, s): for metric in self.metrics: if metric.name == "loss": if self.disable_encryption: - loss = self.loss_fn(y, y_pred) + loss = self.loss_fn(labels, y_pred) metric.update_state(loss) else: # Loss is unknown when encrypted. metric.update_state(0.0) else: if self.disable_encryption: - metric.update_state(y, y_pred) + metric.update_state(labels, y_pred) else: # Other metrics are uknown when encrypted. zeros = tf.broadcast_to(0, tf.shape(y_pred)) diff --git a/tf_shell_ml/test/BUILD b/tf_shell_ml/test/BUILD index 22e20d2..64a8b19 100644 --- a/tf_shell_ml/test/BUILD +++ b/tf_shell_ml/test/BUILD @@ -115,3 +115,14 @@ py_test( requirement("tensorflow"), ], ) + +py_test( + name = "postscale_model_distrib_test", + size = "enormous", + srcs = ["postscale_model_distrib_test.py"], + tags = ["exclusive"], + deps = [ + "//tf_shell_ml", + requirement("tensorflow"), + ], +) diff --git a/tf_shell_ml/test/dpsgd_conv_model_local_test.py b/tf_shell_ml/test/dpsgd_conv_model_local_test.py index 7144144..97672b2 100644 --- a/tf_shell_ml/test/dpsgd_conv_model_local_test.py +++ b/tf_shell_ml/test/dpsgd_conv_model_local_test.py @@ -34,10 +34,14 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) # Clip dataset images to limit memory usage. The model accuracy will be # bad but this test only measures functionality. 
- # x_train, x_test = x_train[:, :512], x_test[:, :512] + # x_train, x_test = x_train[:, 5:23, 5:23, :], x_test[:, 5:23, 5:23, :] + # print(x_train.shape, x_test.shape) - train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(2**13) + labels_dataset = tf.data.Dataset.from_tensor_slices(y_train) + labels_dataset = labels_dataset.batch(2**10) + + features_dataset = tf.data.Dataset.from_tensor_slices(x_train) + features_dataset = features_dataset.batch(2**10) val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) val_dataset = val_dataset.batch(32) @@ -106,7 +110,8 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) m.summary() history = m.fit( - train_dataset, + features_dataset, + labels_dataset, steps_per_epoch=8, epochs=1, verbose=2, diff --git a/tf_shell_ml/test/dpsgd_model_distrib_test.py b/tf_shell_ml/test/dpsgd_model_distrib_test.py index 34eeca7..d6af1df 100644 --- a/tf_shell_ml/test/dpsgd_model_distrib_test.py +++ b/tf_shell_ml/test/dpsgd_model_distrib_test.py @@ -21,7 +21,6 @@ job_prefix = "tfshell" features_party_job = f"{job_prefix}features" labels_party_job = f"{job_prefix}labels" -coordinator_party_job = f"{job_prefix}coordinator" features_party_dev = f"/job:{features_party_job}/replica:0/task:0/device:CPU:0" labels_party_dev = f"/job:{labels_party_job}/replica:0/task:0/device:CPU:0" @@ -38,7 +37,6 @@ def test_model(self): cluster = tf.train.ClusterSpec( { - f"{coordinator_party_job}": ["localhost:2222"], f"{features_party_job}": ["localhost:2223"], f"{labels_party_job}": ["localhost:2224"], } @@ -55,7 +53,7 @@ def test_model(self): # Register the tf-shell ops. import tf_shell - if self.job_name != coordinator_party_job: + if self.job_name == labels_party_job: print(f"{self.job_name} server started.", flush=True) server.join() return @@ -74,53 +72,59 @@ def test_model(self): # bad but this test only measures functionality. 
x_train, x_test = x_train[:, :250], x_test[:, :250] - train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = train_dataset.shuffle(buffer_size=2**14).batch(2**12) - - val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) - val_dataset = val_dataset.batch(32) - - cache_dir = tempfile.TemporaryDirectory() - cache = cache_dir.name - - m = tf_shell_ml.DpSgdSequential( - [ - tf_shell_ml.ShellDense( - 64, - activation=tf_shell_ml.relu, - activation_deriv=tf_shell_ml.relu_deriv, + with tf.device(labels_party_dev): + labels_dataset = tf.data.Dataset.from_tensor_slices(y_train) + labels_dataset = labels_dataset.batch(2**10) + + with tf.device(features_party_dev): + features_dataset = tf.data.Dataset.from_tensor_slices(x_train) + features_dataset = features_dataset.batch(2**10) + + val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + val_dataset = val_dataset.batch(32) + + cache_dir = tempfile.TemporaryDirectory() + cache = cache_dir.name + + m = tf_shell_ml.DpSgdSequential( + [ + tf_shell_ml.ShellDense( + 64, + activation=tf_shell_ml.relu, + activation_deriv=tf_shell_ml.relu_deriv, + ), + tf_shell_ml.ShellDense( + 10, + activation=tf.nn.softmax, + ), + ], + lambda: tf_shell.create_autocontext64( + log2_cleartext_sz=23, + scaling_factor=32, + noise_offset_log2=14, + cache_path=cache, ), - tf_shell_ml.ShellDense( - 10, - activation=tf.nn.softmax, + lambda: tf_shell.create_autocontext64( + log2_cleartext_sz=24, + scaling_factor=1, + noise_offset_log2=0, + cache_path=cache, ), - ], - lambda: tf_shell.create_autocontext64( - log2_cleartext_sz=23, - scaling_factor=32, - noise_offset_log2=14, + labels_party_dev=labels_party_dev, + features_party_dev=features_party_dev, cache_path=cache, - ), - lambda: tf_shell.create_autocontext64( - log2_cleartext_sz=24, - scaling_factor=1, - noise_offset_log2=0, - cache_path=cache, - ), - labels_party_dev=labels_party_dev, - features_party_dev=features_party_dev, - cache_path=cache, - ) + ) - m.compile( - shell_loss=tf_shell_ml.CategoricalCrossentropy(), - optimizer=tf.keras.optimizers.Adam(0.1), - metrics=[tf.keras.metrics.CategoricalAccuracy()], - ) + m.compile( + shell_loss=tf_shell_ml.CategoricalCrossentropy(), + optimizer=tf.keras.optimizers.Adam(0.1), + metrics=[tf.keras.metrics.CategoricalAccuracy()], + ) history = m.fit( - train_dataset, - steps_per_epoch=8, + features_dataset, + labels_dataset, + steps_per_epoch=2, epochs=1, verbose=2, validation_data=val_dataset, @@ -128,7 +132,7 @@ def test_model(self): cache_dir.cleanup() - self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3) + self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25) if __name__ == "__main__": @@ -138,15 +142,8 @@ def test_model(self): tf.test.main() os._exit(0) - features_pid = os.fork() - if features_pid == 0: # child process - TestDistribModel.job_name = features_party_job - tf.test.main() - os._exit(0) - - TestDistribModel.job_name = coordinator_party_job + TestDistribModel.job_name = features_party_job tf.test.main() os.waitpid(labels_pid, 0) - os.waitpid(features_pid, 0) print("Both parties finished.") diff --git a/tf_shell_ml/test/dpsgd_model_local_test.py b/tf_shell_ml/test/dpsgd_model_local_test.py index 60baacc..39b9633 100644 --- a/tf_shell_ml/test/dpsgd_model_local_test.py +++ b/tf_shell_ml/test/dpsgd_model_local_test.py @@ -33,10 +33,13 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) # Clip dataset images to limit memory usage. 
The model accuracy will be # bad but this test only measures functionality. - x_train, x_test = x_train[:, :300], x_test[:, :300] + x_train, x_test = x_train[:, :350], x_test[:, :350] - train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(2**12) + labels_dataset = tf.data.Dataset.from_tensor_slices(y_train) + labels_dataset = labels_dataset.batch(2**10) + + features_dataset = tf.data.Dataset.from_tensor_slices(x_train) + features_dataset = features_dataset.batch(2**10) val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) val_dataset = val_dataset.batch(32) @@ -81,25 +84,28 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) metrics=[tf.keras.metrics.CategoricalAccuracy()], ) - m.build([None, 300]) + m.build([None, 350]) m.summary() history = m.fit( - train_dataset, - steps_per_epoch=8, + features_dataset, + labels_dataset, + steps_per_epoch=2, epochs=1, verbose=2, validation_data=val_dataset, ) - self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.30) + self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25) def test_model(self): with tempfile.TemporaryDirectory() as cache_dir: + # Perform full encrypted test to populate cache. self._test_model(False, False, False, cache_dir) self._test_model(True, False, False, cache_dir) self._test_model(False, True, False, cache_dir) self._test_model(False, False, True, cache_dir) + self._test_model(True, True, True, cache_dir) if __name__ == "__main__": diff --git a/tf_shell_ml/test/mnist_post_scale_test.py b/tf_shell_ml/test/mnist_post_scale_test.py index 52c39ec..90ef0ca 100644 --- a/tf_shell_ml/test/mnist_post_scale_test.py +++ b/tf_shell_ml/test/mnist_post_scale_test.py @@ -53,11 +53,12 @@ mnist_layers = [ tf.keras.layers.Dense(64, activation="relu"), - tf.keras.layers.Dense(10, activation="sigmoid"), + tf.keras.layers.Dense(10, activation="softmax"), ] model = keras.Sequential(mnist_layers) model.compile( + loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"], ) @@ -96,8 +97,8 @@ def train_step(x, y): # Reset the activation function for the last layer and compute the real # prediction. - model.layers[-1].activation = tf.keras.activations.sigmoid - y_pred = model(x, training=False) + model.layers[-1].activation = tf.keras.activations.softmax + y_pred = tf.nn.softmax(y_pred) # Compute y_pred - y (where y may be encrypted). # scalars = y_pred - y # dJ/dy_pred @@ -149,7 +150,7 @@ def train_step(x, y): class TestPlaintextPostScale(tf.test.TestCase): - def _test_mnist_post_scale(self, eager_mode): + def _test_mnist_post_scale_eager_vs_deferred(self, eager_mode): tf.config.run_functions_eagerly(eager_mode) (x_batch, y_batch) = next(iter(train_dataset)) @@ -157,12 +158,6 @@ def _test_mnist_post_scale(self, eager_mode): # Plaintext ps_grads = train_step(x_batch, y_batch) - if not eager_mode: - # With autograph on (eagerly off), the tf.function trace cannot be - # reused between plaintext and encrypted calls. Reset the graph - # between plaintext and encrypted train_step() calls. 
- tf.keras.backend.clear_session() - # Encrypted enc_y_batch = tf_shell.to_encrypted(y_batch, key, context) shell_ps_grads = train_step(x_batch, enc_y_batch) @@ -174,10 +169,42 @@ def _test_mnist_post_scale(self, eager_mode): atol=1 / context.scaling_factor * batch_size, ) - def test_mnist_post_scale(self): + def test_mnist_post_scale_eager_vs_deferred(self): + for eager_mode in [False, True]: + with self.subTest(f"{self._testMethodName} with eager_mode={eager_mode}."): + self._test_mnist_post_scale_eager_vs_deferred(eager_mode) + + def _test_mnist_post_scale_vs_keras(self, eager_mode): + tf.config.run_functions_eagerly(eager_mode) + train_dataset_iter = iter(train_dataset) + + for i in range(8): + (x_batch, y_batch) = next(train_dataset_iter) + # Plaintext post scale. + ps_grads = train_step(x_batch, y_batch) + ps_grads = [tf.reduce_sum(g, axis=0) for g in ps_grads] + + # Keras normal backprop. + with tf.GradientTape() as tape: + y_pred = model(x_batch, training=True) # forward pass + loss = tf.keras.losses.categorical_crossentropy(y_batch, y_pred) + grads = tape.gradient(loss, model.trainable_variables) + + # Compare the gradients. + self.assertAllClose(ps_grads, grads, atol=1e-3) + + # Apply the gradients. + model.optimizer.apply_gradients(zip(ps_grads, model.trainable_variables)) + + # Evaluate the model. + val_loss, val_acc = model.evaluate(val_dataset) + print(f"Validation loss: {val_loss}, accuracy: {val_acc}") + self.assertGreater(val_acc, 0.3) + + def test_mnist_post_scale_eager_vs_deferred(self): for eager_mode in [False, True]: with self.subTest(f"{self._testMethodName} with eager_mode={eager_mode}."): - self._test_mnist_post_scale(eager_mode) + self._test_mnist_post_scale_vs_keras(eager_mode) if __name__ == "__main__": diff --git a/tf_shell_ml/test/postscale_model_distrib_test.py b/tf_shell_ml/test/postscale_model_distrib_test.py new file mode 100644 index 0000000..bbf4307 --- /dev/null +++ b/tf_shell_ml/test/postscale_model_distrib_test.py @@ -0,0 +1,145 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import unittest +import os +import tensorflow as tf +import tempfile + +job_prefix = "tfshell" +features_party_job = f"{job_prefix}features" +labels_party_job = f"{job_prefix}labels" +features_party_dev = f"/job:{features_party_job}/replica:0/task:0/device:CPU:0" +labels_party_dev = f"/job:{labels_party_job}/replica:0/task:0/device:CPU:0" + + +class TestDistribModel(tf.test.TestCase): + job_name = None + + def test_model(self): + print(f"Job name: {self.job_name}") + # flush stdout + import sys + + sys.stdout.flush() + + cluster = tf.train.ClusterSpec( + { + f"{features_party_job}": ["localhost:2223"], + f"{labels_party_job}": ["localhost:2224"], + } + ) + + server = tf.distribute.Server( + cluster, + job_name=self.job_name, + task_index=0, + ) + + tf.config.experimental_connect_to_cluster(cluster) + + # Register the tf-shell ops. 
+ import tf_shell + + if self.job_name == labels_party_job: + print(f"{self.job_name} server started.", flush=True) + server.join() + return + + import keras + import tf_shell_ml + import numpy as np + + # Prepare the dataset. (Note this must be after forking) + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784)) + x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0) + y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10) + + # Clip dataset images to limit memory usage. The model accuracy will be + # bad but this test only measures functionality. + x_train, x_test = x_train[:, :250], x_test[:, :250] + + # Set a seed for shuffling both features and labels the same way. + seed = 42 + + with tf.device(labels_party_dev): + labels_dataset = tf.data.Dataset.from_tensor_slices(y_train) + labels_dataset = labels_dataset.batch(2**10) + + with tf.device(features_party_dev): + features_dataset = tf.data.Dataset.from_tensor_slices(x_train) + features_dataset = features_dataset.batch(2**10) + + val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + val_dataset = val_dataset.batch(32) + + cache_dir = tempfile.TemporaryDirectory() + cache = cache_dir.name + + m = tf_shell_ml.PostScaleSequential( + [ + tf.keras.layers.Dense(64, activation="relu"), + tf.keras.layers.Dense(10, activation="softmax"), + ], + lambda: tf_shell.create_autocontext64( + log2_cleartext_sz=23, + scaling_factor=32, + noise_offset_log2=14, + cache_path=cache, + ), + lambda: tf_shell.create_autocontext64( + log2_cleartext_sz=24, + scaling_factor=1, + noise_offset_log2=0, + cache_path=cache, + ), + labels_party_dev=labels_party_dev, + features_party_dev=features_party_dev, + cache_path=cache, + ) + + m.compile( + shell_loss=tf_shell_ml.CategoricalCrossentropy(), + optimizer=tf.keras.optimizers.Adam(0.1), + metrics=[tf.keras.metrics.CategoricalAccuracy()], + ) + + history = m.fit( + features_dataset, + labels_dataset, + steps_per_epoch=2, + epochs=1, + verbose=2, + validation_data=val_dataset, + ) + + cache_dir.cleanup() + + self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3) + + +if __name__ == "__main__": + labels_pid = os.fork() + if labels_pid == 0: # child process + TestDistribModel.job_name = labels_party_job + tf.test.main() + os._exit(0) + + TestDistribModel.job_name = features_party_job + tf.test.main() + + os.waitpid(labels_pid, 0) + print("Both parties finished.") diff --git a/tf_shell_ml/test/postscale_model_local_test.py b/tf_shell_ml/test/postscale_model_local_test.py index 6d51a85..6652bbe 100644 --- a/tf_shell_ml/test/postscale_model_local_test.py +++ b/tf_shell_ml/test/postscale_model_local_test.py @@ -20,10 +20,11 @@ import tf_shell import tf_shell_ml import os +import tempfile class TestModel(tf.test.TestCase): - def _test_model(self, disable_encryption, disable_masking, disable_noise): + def _test_model(self, disable_encryption, disable_masking, disable_noise, cache): # Prepare the dataset. (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784)) @@ -32,17 +33,17 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise): # Clip dataset images to limit memory usage. The model accuracy will be # bad but this test only measures functionality. 
- x_train, x_test = x_train[:, :380], x_test[:, :380] + x_train, x_test = x_train[:, :350], x_test[:, :350] - train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(2**12) + labels_dataset = tf.data.Dataset.from_tensor_slices(y_train) + labels_dataset = labels_dataset.batch(2**10) + + features_dataset = tf.data.Dataset.from_tensor_slices(x_train) + features_dataset = features_dataset.batch(2**10) val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) val_dataset = val_dataset.batch(32) - context_cache_path = "/tmp/postscale_model_local_test_cache/" - os.makedirs(context_cache_path, exist_ok=True) - m = tf_shell_ml.PostScaleSequential( [ tf.keras.layers.Dense(64, activation="relu"), @@ -52,18 +53,18 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise): log2_cleartext_sz=23, scaling_factor=32, noise_offset_log2=14, - cache_path=context_cache_path, + cache_path=cache, ), lambda: tf_shell.create_autocontext64( log2_cleartext_sz=24, scaling_factor=1, noise_offset_log2=0, - cache_path=context_cache_path, + cache_path=cache, ), disable_encryption=disable_encryption, disable_masking=disable_masking, disable_noise=disable_noise, - cache_path=context_cache_path, + cache_path=cache, ) m.compile( @@ -73,8 +74,9 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise): ) history = m.fit( - train_dataset, - steps_per_epoch=8, + features_dataset, + labels_dataset, + steps_per_epoch=2, epochs=1, verbose=2, validation_data=val_dataset, @@ -83,10 +85,13 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise): self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25) def test_model(self): - self._test_model(False, False, False) - self._test_model(True, False, False) - self._test_model(False, True, False) - self._test_model(False, False, True) + with tempfile.TemporaryDirectory() as cache_dir: + # Perform full encrypted test to populate cache. + self._test_model(False, False, False, cache_dir) + self._test_model(True, False, False, cache_dir) + self._test_model(False, True, False, cache_dir) + self._test_model(False, False, True, cache_dir) + self._test_model(True, True, True, cache_dir) if __name__ == "__main__":
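
Reviewer note (not part of the patch): the comments in dpsgd_sequential_model.py and postscale_sequential_model.py lean on the identity that the derivative of categorical crossentropy composed with a softmax is simply y_pred - labels, which is why the last layer's activation derivative is set to a no-op. A minimal plaintext sketch that checks this numerically with standard TensorFlow APIs (all names below are illustrative, not from the patch):

import tensorflow as tf

logits = tf.random.normal([4, 10])
labels = tf.one_hot(tf.random.uniform([4], maxval=10, dtype=tf.int32), depth=10)

with tf.GradientTape() as tape:
    tape.watch(logits)
    y_pred = tf.nn.softmax(logits)
    # Per-example cross entropy, summed so the gradient keeps the batch dim.
    loss = tf.reduce_sum(tf.keras.losses.categorical_crossentropy(labels, y_pred))

autodiff_grad = tape.gradient(loss, logits)   # d(loss)/d(logits)
manual_grad = y_pred - labels                 # what the no-op activation trick computes

tf.debugging.assert_near(autodiff_grad, manual_grad, atol=1e-4)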
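
Reviewer note (not part of the patch): a plaintext-only sketch of the post-scale computation that PostScaleSequential's shell_train_step performs, assuming a toy two-layer Dense model. In the real protocol `scalars` is the encrypted y_pred - labels and the contraction over the class axis happens under encryption; here everything is in the clear purely to illustrate the shapes involved:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10),  # linear output; softmax is applied manually below
])
model.build([None, 20])

x = tf.random.normal([8, 20])
labels = tf.one_hot(tf.random.uniform([8], maxval=10, dtype=tf.int32), depth=10)

with tf.GradientTape() as tape:
    logits = model(x, training=True)
# Per-example Jacobians: each entry has shape [batch, classes, *var_shape].
jacobians = tape.jacobian(logits, model.trainable_variables)

y_pred = tf.nn.softmax(logits)
scalars = y_pred - labels  # dJ/d(logits); encrypted in the real protocol

per_example_grads = []
for jac in jacobians:
    # Broadcast the per-class scalars over the weight dimensions, then
    # contract the class axis to get one gradient per example.
    extra_dims = len(jac.shape) - 2
    s = tf.reshape(scalars, scalars.shape.as_list() + [1] * extra_dims)
    per_example_grads.append(tf.reduce_sum(s * jac, axis=1))

# Summing over the batch recovers the ordinary aggregated gradient.
summed_grads = [tf.reduce_sum(g, axis=0) for g in per_example_grads]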
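
Reviewer note (not part of the patch): the warn_on_overflow docstring added in model_base.py flags decrypted gradients that fall in the outer half of the plaintext range, after accounting for the fixed-point scaling factor. A hypothetical standalone version of that check, written only to restate the rule in code (the actual implementation in the patch may differ):

import tensorflow as tf

def might_have_overflowed(grad, scaling_factor, plaintext_modulus):
    # Values whose magnitude exceeds t / 4 (after undoing the fixed-point
    # scaling) fall in [-t/2, -t/4] or [t/4, t/2] and may have wrapped
    # around the plaintext modulus during the encrypted computation.
    t = tf.cast(plaintext_modulus, grad.dtype)
    bound = t / 4.0 / tf.cast(scaling_factor, grad.dtype)
    return tf.reduce_any(tf.abs(grad) > bound)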