diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py
index 782b033..0931b77 100644
--- a/tf_shell_ml/dpsgd_sequential_model.py
+++ b/tf_shell_ml/dpsgd_sequential_model.py
@@ -26,11 +26,14 @@ def __init__(
         self,
         layers,
         shell_context_fn,
-        use_encryption,
         labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
         features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
         needs_public_rotation_key=False,
         noise_multiplier=1.0,
+        disable_encryption=False,
+        disable_masking=False,
+        disable_noise=False,
+        check_overflow_close=True,
         *args,
         **kwargs,
     ):
@@ -48,11 +51,14 @@ def __init__(
             self.layers[-1].activation_deriv = None
 
         self.shell_context_fn = shell_context_fn
-        self.use_encryption = use_encryption
         self.labels_party_dev = labels_party_dev
         self.features_party_dev = features_party_dev
-        self.clipping_threshold = 10000000
         self.needs_public_rotation_key = needs_public_rotation_key
+        self.noise_multiplier = noise_multiplier
+        self.disable_encryption = disable_encryption
+        self.disable_masking = disable_masking
+        self.disable_noise = disable_noise
+        self.check_overflow_close = check_overflow_close
 
     def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
         super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
@@ -78,14 +84,40 @@ def build(self, input_shape):
 
             if hasattr(l, "unpacking_funcs"):
                 self.unpacking_funcs.extend(l.unpacking_funcs())
 
+    def compute_max_two_norm_and_pred(self, features):
+        with tf.GradientTape(persistent=tf.executing_eagerly()) as tape:
+            y_pred = self(features, training=True)  # forward pass
+        grads = tape.jacobian(
+            y_pred,
+            self.trainable_variables,
+            parallel_iterations=1,
+            experimental_use_pfor=False,
+        )
+        # ^ layers list x (batch size x num output classes x weights) matrix
+
+        if len(grads) == 0:
+            raise ValueError("No gradients found")
+        slot_size = tf.shape(grads[0])[0]
+        num_output_classes = grads[0].shape[1]
+
+        flat_grads = [tf.reshape(g, [slot_size, num_output_classes, -1]) for g in grads]
+        # ^ layers list (batch_size x num output classes x flattened layer weights)
+        all_grads = tf.concat(flat_grads, axis=2)
+        # ^ batch_size x num output classes x all flattened weights
+        two_norms = tf.map_fn(lambda x: tf.norm(x, axis=0), all_grads)
+        max_two_norm = tf.reduce_max(two_norms)
+        return max_two_norm, y_pred
+
     def train_step(self, data):
         x, y = data
 
         with tf.device(self.labels_party_dev):
-            if self.use_encryption:
+            if self.disable_encryption:
+                enc_y = y
+            else:
                 key_path = tempfile.mkdtemp()  # Every trace gets a new key.
-                shell_context = self.shell_context_fn()
+                shell_context = self.shell_context_fn()
                 secret_key = tf_shell.create_key64(
                     shell_context, key_path + "/secret_key"
                 )
@@ -98,12 +130,11 @@ def train_step(self, data):
                     )
                 # Encrypt the batch of secret labels y.
                 enc_y = tf_shell.to_encrypted(y, secret_key, shell_context)
-            else:
-                enc_y = y
 
         with tf.device(self.features_party_dev):
             # Forward pass in plaintext.
-            y_pred = self(x, training=True)
+            # y_pred = self(x, training=True)
+            max_two_norm, y_pred = self.compute_max_two_norm_and_pred(x)
 
             # Backward pass.
             dx = self.loss_fn.grad(enc_y, y_pred)
@@ -115,7 +146,12 @@ def train_step(self, data):
                 else:
                     dw, dx = l.backward(
                         dJ_dx[-1],
-                        public_rotation_key if self.needs_public_rotation_key else None,
+                        (
+                            public_rotation_key
+                            if not self.disable_encryption
+                            and self.needs_public_rotation_key
+                            else None
+                        ),
                     )
                 dJ_dw.extend(dw)
                 dJ_dx.append(dx)
@@ -123,30 +159,38 @@
             if len(dJ_dw) == 0:
                 raise ValueError("No gradients found.")
 
-            # Setup parameters for the masking.
-            t = tf.cast(shell_context.plaintext_modulus, tf.float32)
-            t_half = t // 2
-            mask_scaling_factors = [g._scaling_factor for g in reversed(dJ_dw)]
-
-            # Mask the encrypted grads to prepare for decryption.
-            mask = [
-                tf.random.uniform(
-                    tf_shell.shape(g),
-                    dtype=tf.float32,
-                    minval=-t_half / s,
-                    maxval=t_half / s,
-                )
-                for g, s in zip(reversed(dJ_dw), mask_scaling_factors)
-                # tf.zeros_like(tf_shell.shape(g), dtype=tf.int64)
-                # for g in dJ_dw
-            ]
+            # Mask the encrypted grads to prepare for decryption. The masks may
+            # overflow during the reduce_sum over the batch. When the masks are
+            # operated on, they are multiplied by the scaling factor, so it is
+            # not necessary to mask the full range -t/2 to t/2. (Though it is
+            # possible, it unnecessarily introduces noise into the ciphertext.)
+            if self.disable_masking or self.disable_encryption:
+                masked_enc_grads = [g for g in reversed(dJ_dw)]
+            else:
+                t = tf.cast(shell_context.plaintext_modulus, tf.float32)
+                t_half = t // 2
+                mask_scaling_factors = [g._scaling_factor for g in reversed(dJ_dw)]
+                mask = [
+                    tf.random.uniform(
+                        tf_shell.shape(g),
+                        dtype=tf.float32,
+                        minval=-t_half / s,
+                        maxval=t_half / s,
+                    )
+                    for g, s in zip(reversed(dJ_dw), mask_scaling_factors)
+                    # tf.zeros_like(tf_shell.shape(g), dtype=tf.int64)
+                    # for g in dJ_dw
+                ]
 
-            # Mask the encrypted gradients and reverse the order to match the
-            # order of the layers.
-            masked_enc_grads = [(g + m) for g, m in zip(reversed(dJ_dw), mask)]
+                # Mask the encrypted gradients and reverse the order to match
+                # the order of the layers.
+                masked_enc_grads = [(g + m) for g, m in zip(reversed(dJ_dw), mask)]
 
         with tf.device(self.labels_party_dev):
-            if self.use_encryption:
+            if self.disable_encryption:
+                # Unpacking is not necessary when not using encryption.
+                masked_grads = masked_enc_grads
+            else:
                 # Decrypt the weight gradients.
                 packed_masked_grads = [
                     tf_shell.to_tensorflow(
@@ -164,47 +208,56 @@ def train_step(self, data):
                 masked_grads = [
                     f(g) for f, g in zip(self.unpacking_funcs, packed_masked_grads)
                 ]
-            else:
-                masked_grads = masked_enc_grads
 
         with tf.device(self.features_party_dev):
-            # SHELL represents floats as integers between [0, t) where t is the
-            # plaintext modulus. To mimic the modulo operation without SHELL,
-            # numbers which exceed the range [-t/2, t/2) are shifted back into
-            # the range.
-            def rebalance(x_list, t_half, scaling_factor_list):
-                x_list = [
-                    tf.where(x > t_half / s + (1 / s - 1e-6), x - t / s, x)
-                    for x, s in zip(x_list, scaling_factor_list)
-                ]
-                x_list = [
-                    tf.where(x < -t_half / s - (1 / s - 1e-6), x + t / s, x)
-                    for x, s in zip(x_list, scaling_factor_list)
+            if self.disable_masking or self.disable_encryption:
+                grads = masked_grads
+            else:
+                # SHELL represents floats as integers between [0, t) where t is the
+                # plaintext modulus. To mimic the modulo operation without SHELL,
+                # numbers which exceed the range [-t/2, t/2) are shifted back into
+                # the range.
+                epsilon = tf.constant(1e-6, dtype=float)
+
+                def rebalance(x, s):
+                    r_bound = t_half / s + epsilon
+                    l_bound = -t_half / s - epsilon
+                    t_over_s = t / s
+                    x = tf.where(x > r_bound, x - t_over_s, x)
+                    x = tf.where(x < l_bound, x + t_over_s, x)
+                    return x
+
+                # Unmask the gradients using the mask. The unpacking function may
+                # sum the mask from two of the gradients (one from each batch), so
+                # the mask must be brought back into the range of [-t/2, t/2] before
+                # subtracting it from the gradient, and again after.
+                unpacked_mask = [f(m) for f, m in zip(self.unpacking_funcs, mask)]
+                unpacked_mask = [
+                    rebalance(m, s) for m, s in zip(unpacked_mask, mask_scaling_factors)
                 ]
-                return x_list
-
-            # Unmask the gradients using the mask. The unpacking function may
-            # sum the mask from two of the gradients (one from each batch), so
-            # the mask must be brought back into the range of [-t/2, t/2] before
-            # subtracting it from the gradient, and again after.
-            unpacked_mask = [f(m) for f, m in zip(self.unpacking_funcs, mask)]
-            unpacked_mask = rebalance(unpacked_mask, t_half, mask_scaling_factors)
-            grads = [mg - m for mg, m in zip(masked_grads, unpacked_mask)]
-            grads = rebalance(grads, t_half, mask_scaling_factors)
+                grads = [mg - m for mg, m in zip(masked_grads, unpacked_mask)]
+                grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]
 
             # TODO: set stddev based on clipping threshold.
-            noise = [
-                # tf.random.normal(tf.shape(g), stddev=1, dtype=float) for g in grads
-                tf.zeros(tf.shape(g))
-                for g in grads
-            ]
-            noised_grads = [g + n for g, n in zip(grads, noise)]
+            if self.disable_noise:
+                noised_grads = grads
+            else:
+                noise = [
+                    tf.random.normal(
+                        tf.shape(g),
+                        stddev=max_two_norm * self.noise_multiplier,
+                        dtype=float,
+                    )
+                    # tf.zeros(tf.shape(g))
+                    for g in grads
+                ]
+                noised_grads = [g + n for g, n in zip(grads, noise)]
 
             # Apply the gradients to the model.
             self.optimizer.apply_gradients(zip(noised_grads, self.weights))
 
             # Do not update metrics during secure training.
-            if not self.use_encryption:
+            if self.disable_encryption:
                 # Update metrics (includes the metric that tracks the loss)
                 for metric in self.metrics:
                     if metric.name == "loss":
@@ -215,9 +268,8 @@ def rebalance(x_list, t_half, scaling_factor_list):
                 metric_results = {m.name: m.result() for m in self.metrics}
 
             else:
-                metric_results = {}
+                metric_results = {"num_slots": shell_context.num_slots}
 
-            metric_results["num_slots"] = shell_context.num_slots
             return metric_results
 
     def test_step(self, data):
diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py
index f8f24b5..4f4b0ee 100644
--- a/tf_shell_ml/model_base.py
+++ b/tf_shell_ml/model_base.py
@@ -37,13 +37,24 @@ def train_step_tf_func(self, data):
     # to the same value as the encryption ring degree. Run the training loop once
     # on dummy data to figure out the batch size.
     def prep_dataset_for_model(self, train_dataset):
-        if not self.use_encryption:
+        if self.disable_encryption:
             self.dataset_prepped = True
-            return
+            return train_dataset
 
         # Run the training loop once on dummy data to figure out the batch size.
         tf.config.run_functions_eagerly(False)
         metrics = self.train_step_tf_func(next(iter(train_dataset)))
+
+        if not isinstance(metrics, dict):
+            raise ValueError(
+                f"Expected train_step to return a dict, got {type(metrics)}."
+            )
+
+        if "num_slots" not in metrics:
+            raise ValueError(
+                f"Expected train_step to return a dict with key 'num_slots', got {metrics.keys()}."
+            )
+
         train_dataset = train_dataset.rebatch(
             metrics["num_slots"].numpy(), drop_remainder=True
         )
@@ -56,8 +67,8 @@ def prep_dataset_for_model(self, train_dataset):
     # `prep_dataset_for_model` because it does not execute the graph, instead
     # tracing and optimizing the graph and extracting the required parameters.
     def fast_prep_dataset_for_model(self, train_dataset):
-        if not self.use_encryption:
-            return
+        if not self.disable_encryption:
+            return train_dataset
 
         # Call the training step with keygen to trace the graph. Use a copy
         # of the function to avoid caching the trace.
diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py
index 16d1870..5e9c3ab 100644
--- a/tf_shell_ml/postscale_sequential_model.py
+++ b/tf_shell_ml/postscale_sequential_model.py
@@ -25,11 +25,13 @@ def __init__(
         self,
         layers,
         shell_context_fn,
-        use_encryption,
         labels_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
         features_party_dev="/job:localhost/replica:0/task:0/device:CPU:0",
-        needs_public_rotation_key=False,
         noise_multiplier=1.0,
+        disable_encryption=False,
+        disable_masking=False,
+        disable_noise=False,
+        check_overflow_close=True,
         *args,
         **kwargs,
     ):
@@ -47,13 +49,15 @@ def __init__(
             self.layers[-1].activation_deriv = None
 
         self.shell_context_fn = shell_context_fn
-        self.use_encryption = use_encryption
         self.labels_party_dev = labels_party_dev
         self.features_party_dev = features_party_dev
         self.clipping_threshold = 10000000
         self.context_prepped = False
-        self.needs_public_rotation_key = needs_public_rotation_key
+        self.disable_encryption = disable_encryption
         self.noise_multiplier = noise_multiplier
+        self.disable_masking = disable_masking
+        self.disable_noise = disable_noise
+        self.check_overflow_close = check_overflow_close
 
     def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
         super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
@@ -75,24 +79,16 @@ def train_step(self, data):
         x, y = data
 
         with tf.device(self.labels_party_dev):
-            if self.use_encryption:
+            if self.disable_encryption:
+                enc_y = y
+            else:
                 key_path = tempfile.mkdtemp()  # Every trace gets a new key.
                 shell_context = self.shell_context_fn()
-
                 secret_key = tf_shell.create_key64(
                     shell_context, key_path + "/secret_key"
                 )
-                secret_fast_rotation_key = tf_shell.create_fast_rotation_key64(
-                    shell_context, secret_key, key_path + "/secret_fast_rotation_key"
-                )
-                if self.needs_public_rotation_key:
-                    public_rotation_key = tf_shell.create_rotation_key64(
-                        shell_context, secret_key, key_path + "/public_rotation_key"
-                    )
                 # Encrypt the batch of secret labels y.
                 enc_y = tf_shell.to_encrypted(y, secret_key, shell_context)
-            else:
-                enc_y = y
 
         with tf.device(self.features_party_dev):
             # Unset the activation function for the last layer so it is not used in
@@ -100,7 +96,7 @@ def train_step(self, data):
             # is factored out of the gradient computation and accounted for below.
             self.layers[-1].activation = tf.keras.activations.linear
 
-            with tf.GradientTape() as tape:
+            with tf.GradientTape(persistent=tf.executing_eagerly()) as tape:
                 y_pred = self(x, training=True)  # forward pass
             grads = tape.jacobian(
                 y_pred,
@@ -129,9 +125,11 @@ def train_step(self, data):
 
             # Remember the original shape of the gradient in order to unpack them
             # after the multiplication so they can be applied to the model.
+            # Get the shapes from TensorFlow's tensors, not SHELL's context
+            # for when the batch size != slotting dim or not using encryption.
             if len(grads) == 0:
                 raise ValueError("No gradients found")
-            slot_size = grads[0].shape[0]
+            slot_size = tf.shape(grads[0])[0]
             num_output_classes = grads[0].shape[1]
             grad_shapes = [g.shape[2:] for g in grads]
             flattened_grad_shapes = [s.num_elements() for s in grad_shapes]
@@ -159,18 +157,6 @@ def train_step(self, data):
             # Sum over the output classes.
             scaled_grads = tf_shell.reduce_sum(scaled_grads, axis=1)
 
-            # Sum over the batch.
-            if self.use_encryption:
-                if self.needs_public_rotation_key:
-                    scaled_grads = tf_shell.reduce_sum(
-                        scaled_grads, axis=0, rotation_key=public_rotation_key
-                    )
-                else:
-                    scaled_grads = tf_shell.fast_reduce_sum(scaled_grads)
-            else:
-                scaled_grads = tf_shell.reduce_sum(scaled_grads, axis=0)
-            # ^ batch_size x flattened weights
-
             # Split to recover the gradients by layer.
             ps_grads = tf_shell.split(scaled_grads, flattened_grad_shapes, axis=1)
             # ^ layers list (batch_size x flat layer weights)
@@ -179,95 +165,118 @@ def train_step(self, data):
             ps_grads = [
                 tf_shell.reshape(
                     g,
-                    tf.concat(
-                        [[shell_context.num_slots], tf.cast(s, dtype=tf.int64)], axis=0
-                    ),
+                    tf.concat([[slot_size], tf.cast(s, dtype=tf.int64)], axis=0),
                 )
                 for g, s in zip(ps_grads, grad_shapes)
             ]
 
-            # This cast is safe because the plaintext modulus will always be
-            # less than 63 bits.
-            int_pt_modulus = tf.cast(shell_context.plaintext_modulus, dtype=tf.int64)
-
-            # Setup parameters for the masking.
-            t = tf.cast(shell_context.plaintext_modulus, tf.float32)
-            t_half = t // 2
-            mask_scaling_factors = [g._scaling_factor for g in ps_grads]
-
-            # Mask the encrypted grads to prepare for decryption.
-            mask = [
-                tf.random.uniform(
-                    tf_shell.shape(g),
-                    dtype=tf.float32,
-                    minval=-t_half / s,
-                    maxval=t_half / s,
-                )
-                for g, s in zip(ps_grads, mask_scaling_factors)
-                # tf.zeros_like(tf_shell.shape(g), dtype=tf.int64)
-                # for g in dJ_dw
-            ]
+            # Mask the encrypted grads to prepare for decryption. The masks may
+            # overflow during the reduce_sum over the batch. When the masks are
+            # operated on, they are multiplied by the scaling factor, so it is
+            # not necessary to mask the full range -t/2 to t/2. (Though it is
+            # possible, it unnecessarily introduces noise into the ciphertext.)
+            if self.disable_masking or self.disable_encryption:
+                masked_enc_grads = ps_grads
+            else:
+                t = tf.cast(shell_context.plaintext_modulus, tf.float32)
+                t_half = t // 2
+                mask_scaling_factors = [g._scaling_factor for g in ps_grads]
+                masks = [
+                    tf.random.uniform(
+                        tf_shell.shape(g),
+                        dtype=tf.float32,
+                        minval=-t_half / s,
+                        maxval=t_half / s,
+                    )
+                    for g, s in zip(ps_grads, mask_scaling_factors)
+                ]
 
-            # Mask the encrypted gradients and reverse the order to match the
-            # order of the layers.
-            masked_enc_grads = [g + m for g, m in zip(ps_grads, mask)]
+                # Mask the encrypted gradients to prepare for decryption.
+                masked_enc_grads = [g + m for g, m in zip(ps_grads, masks)]
 
         with tf.device(self.labels_party_dev):
-            if self.use_encryption:
+            if self.disable_encryption:
+                masked_grads = masked_enc_grads
+            else:
                 # Decrypt the weight gradients.
-                packed_masked_grads = [
-                    tf_shell.to_tensorflow(
-                        g,
-                        secret_fast_rotation_key if g._is_fast_rotated else secret_key,
-                    )
-                    for g in ps_grads
+                masked_grads = [
+                    tf_shell.to_tensorflow(g, secret_key) for g in masked_enc_grads
                 ]
-                # Unpack the plaintext gradients using the corresponding layer's
-                # unpack function.
-                masked_grads = self._unpack(packed_masked_grads)
 
+            # Sum the masked gradients over the batch.
+            if self.disable_masking or self.disable_encryption:
+                batch_masked_grads = [tf.reduce_sum(mg, 0) for mg in masked_grads]
             else:
-                masked_grads = ps_grads
+                batch_masked_grads = [
+                    tf_shell.reduce_sum_with_mod(mg, 0, shell_context, s)
+                    for mg, s in zip(masked_grads, mask_scaling_factors)
+                ]
 
         with tf.device(self.features_party_dev):
-            # SHELL represents floats as integers between [0, t) where t is the
-            # plaintext modulus. To mimic SHELL's modulo operations in
-            # TensorFlow, numbers which exceed the range [-t/2, t/2) are shifted
-            # back into the range.
-            def rebalance(x_list, t_half, scaling_factor_list):
-                x_list = [
-                    tf.where(x > t_half / s + (1 / s - 1e-6), x - t / s, x)
-                    for x, s in zip(x_list, scaling_factor_list)
+            if self.disable_masking or self.disable_encryption:
+                batch_grad = batch_masked_grads
+            else:
+                # Sum the masks over the batch.
+                batch_masks = [
+                    tf_shell.reduce_sum_with_mod(m, 0, shell_context, s)
+                    for m, s in zip(masks, mask_scaling_factors)
+                ]
+
+                # Unmask the batch gradient.
+                batch_grad = [mg - m for mg, m in zip(batch_masked_grads, batch_masks)]
+
+                # SHELL represents floats as integers between [0, t) where t is the
+                # plaintext modulus. To mimic SHELL's modulo operations in
+                # TensorFlow, numbers which exceed the range [-t/2, t/2] are shifted
+                # back into the range.
+                epsilon = tf.constant(1e-6, dtype=float)
+
+                def rebalance(x, s):
+                    r_bound = t_half / s + epsilon
+                    l_bound = -t_half / s - epsilon
+                    t_over_s = t / s
+                    x = tf.where(x > r_bound, x - t_over_s, x)
+                    x = tf.where(x < l_bound, x + t_over_s, x)
+                    return x
+
+                batch_grad = [
+                    rebalance(g, s) for g, s in zip(batch_grad, mask_scaling_factors)
                 ]
-                x_list = [
-                    tf.where(x < -t_half / s - (1 / s - 1e-6), x + t / s, x)
-                    for x, s in zip(x_list, scaling_factor_list)
+
+            if not self.disable_encryption and self.check_overflow_close:
+                # If the unmasked gradient is between [-t/2, -t/4] or
+                # [t/4, t/2], the gradient may have overflowed.
+                overflowed = [tf.abs(g) > t_half / 2 for g in batch_grad]
+                overflowed = [tf.reduce_any(o) for o in overflowed]
+                overflowed = tf.reduce_any(overflowed)
+                tf.cond(
+                    overflowed,
+                    lambda: tf.print("Gradient may have overflowed"),
+                    lambda: tf.identity(overflowed),
+                )
+
+            if self.disable_noise:
+                noised_grads = batch_grad
+            else:
+                # Set the noise based on the maximum two norm of the gradient
+                # per example, per output.
+                noise = [
+                    tf.random.normal(
+                        tf.shape(g),
+                        stddev=max_two_norm * self.noise_multiplier,
+                        dtype=float,
+                    )
+                    for g in batch_grad
                 ]
-                return x_list
-
-            # Unmask the gradients using the mask. The mask must be unpacked,
-            # and modulo the plaintext modulus. This can be done with two
-            # subtractions.
-            unpacked_mask = self._unpack(mask)
-            unpacked_mask = rebalance(unpacked_mask, t_half, mask_scaling_factors)
-            dec_grads = [mg - m for mg, m in zip(masked_grads, unpacked_mask)]
-            dec_grads = rebalance(dec_grads, t_half, mask_scaling_factors)
-
-            # Set the noise based on the maximum two norm of the gradient per
-            # example, per output.
-            noise = [
-                tf.random.normal(tf.shape(g), stddev=max_two_norm * self.noise_multiplier, dtype=float)
-                for g in dec_grads
-            ]
-            # ^ layers list (batch_size x weights)
+                # ^ layers list (batch_size x weights)
 
-            noised_grads = [g + n for g, n in zip(dec_grads, noise)]
+                noised_grads = [g + n for g, n in zip(batch_grad, noise)]
 
             # Apply the gradients to the model.
             self.optimizer.apply_gradients(zip(noised_grads, self.weights))
 
             # Do not update metrics during secure training.
-            if not self.use_encryption:
+            if self.disable_encryption:
                 # Update metrics (includes the metric that tracks the loss)
                 for metric in self.metrics:
                     if metric.name == "loss":
@@ -278,7 +287,6 @@ def rebalance(x_list, t_half, scaling_factor_list):
                 metric_results = {m.name: m.result() for m in self.metrics}
 
             else:
-                metric_results = {}
-                metric_results["num_slots"] = shell_context.num_slots
+                metric_results = {"num_slots": shell_context.num_slots}
 
             return metric_results
diff --git a/tf_shell_ml/test/dpsgd_model_distrib_test.py b/tf_shell_ml/test/dpsgd_model_distrib_test.py
index b23b16a..a1db690 100644
--- a/tf_shell_ml/test/dpsgd_model_distrib_test.py
+++ b/tf_shell_ml/test/dpsgd_model_distrib_test.py
@@ -74,7 +74,7 @@ def test_model(self):
         x_train, x_test = x_train[:, :120], x_test[:, :120]
 
         train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
-        train_dataset = train_dataset.shuffle(buffer_size=2**14).batch(4)
+        train_dataset = train_dataset.shuffle(buffer_size=2**14).batch(2**12)
 
         val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
         val_dataset = val_dataset.batch(32)
@@ -102,9 +102,8 @@ def test_model(self):
             lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=32,
                 scaling_factor=3,
-                noise_offset_log2=57,
+                noise_offset_log2=25,
             ),
-            True,
             labels_party_dev=labels_party_dev,
             features_party_dev=features_party_dev,
         )
@@ -117,11 +116,13 @@ def test_model(self):
         )
 
         history = m.fit(
-            train_dataset.take(2**13),
+            train_dataset.take(4),
            epochs=1,
            validation_data=val_dataset,
         )
 
+        self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3)
+
 
 if __name__ == "__main__":
     labels_pid = os.fork()
diff --git a/tf_shell_ml/test/dpsgd_model_local_test.py b/tf_shell_ml/test/dpsgd_model_local_test.py
index 498a4f8..6774dff 100644
--- a/tf_shell_ml/test/dpsgd_model_local_test.py
+++ b/tf_shell_ml/test/dpsgd_model_local_test.py
@@ -22,7 +22,7 @@
 
 
 class TestModel(tf.test.TestCase):
-    def test_model(self):
+    def _test_model(self, disable_encryption, disable_masking, disable_noise):
         # Prepare the dataset.
         (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
         x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
@@ -34,7 +34,7 @@ def test_model(self):
         x_train, x_test = x_train[:, :512], x_test[:, :512]
 
         train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
-        train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(4)
+        train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(2**12)
 
         val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
         val_dataset = val_dataset.batch(32)
@@ -57,11 +57,13 @@ def test_model(self):
                 ),
             ],
             lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=32,
-                scaling_factor=3,
-                noise_offset_log2=64,
+                log2_cleartext_sz=14,
+                scaling_factor=8,
+                noise_offset_log2=12,
             ),
-            use_encryption=True,
+            disable_encryption=False,
+            disable_masking=False,
+            disable_noise=False,
         )
 
         m.compile(
@@ -75,9 +77,15 @@ def test_model(self):
         m.build([None, 512])
         m.summary()
 
-        history = m.fit(
-            train_dataset.take(2**13), epochs=1, validation_data=val_dataset
-        )
+        history = m.fit(train_dataset.take(4), epochs=1, validation_data=val_dataset)
+
+        self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3)
+
+    def test_model(self):
+        self._test_model(False, False, False)
+        self._test_model(True, False, False)
+        self._test_model(False, True, False)
+        self._test_model(False, False, True)
 
 
 if __name__ == "__main__":
diff --git a/tf_shell_ml/test/postscale_model_local_test.py b/tf_shell_ml/test/postscale_model_local_test.py
index 1911e00..da95b56 100644
--- a/tf_shell_ml/test/postscale_model_local_test.py
+++ b/tf_shell_ml/test/postscale_model_local_test.py
@@ -22,7 +22,7 @@
 
 
 class TestModel(tf.test.TestCase):
-    def test_model(self):
+    def _test_model(self, disable_encryption, disable_masking, disable_noise):
         # Prepare the dataset.
         (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
         x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
@@ -31,10 +31,10 @@ def test_model(self):
 
         # Clip dataset images to limit memory usage. The model accuracy will be
         # bad but this test only measures functionality.
-        x_train, x_test = x_train[:, :256], x_test[:, :256]
+        x_train, x_test = x_train[:, :380], x_test[:, :380]
 
         train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
-        train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(4)
+        train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(2**12)
 
         val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
         val_dataset = val_dataset.batch(32)
@@ -44,18 +44,14 @@ def test_model(self):
                 tf.keras.layers.Dense(64, activation="relu"),
                 tf.keras.layers.Dense(10, activation="sigmoid"),
             ],
-            # lambda: tf_shell.create_context64(
-            #     log_n=12,
-            #     main_moduli=[288230376151760897, 288230376152137729],
-            #     plaintext_modulus=4294991873,
-            #     scaling_factor=3,
-            # ),
             lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=32,
-                scaling_factor=3,
-                noise_offset_log2=50,
+                log2_cleartext_sz=14,
+                scaling_factor=2**10,
+                noise_offset_log2=47,  # may be overprovisioned
             ),
-            use_encryption=True,
+            disable_encryption=False,
+            disable_masking=False,
+            disable_noise=False,
        )
 
         m.compile(
@@ -65,9 +61,15 @@ def test_model(self):
             metrics=[tf.keras.metrics.CategoricalAccuracy()],
         )
 
-        history = m.fit(
-            train_dataset.take(2**13), epochs=1, validation_data=val_dataset
-        )
+        history = m.fit(train_dataset.take(4), epochs=1, validation_data=val_dataset)
+
+        self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3)
+
+    def test_model(self):
+        self._test_model(False, False, False)
+        self._test_model(True, False, False)
+        self._test_model(False, True, False)
+        self._test_model(False, False, True)
 
 
 if __name__ == "__main__":
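
Illustrative sketches follow; they are hedged examples for the reader, not code contained in this patch.

The DP noise in both models is now scaled by the largest per-example gradient two-norm rather than a fixed constant: compute_max_two_norm_and_pred runs the forward pass under a GradientTape, takes the Jacobian of the predictions with respect to the weights, and collapses it to a single max_two_norm that multiplies noise_multiplier. A minimal standalone version of that computation on a toy Keras model (the model, layer sizes, and batch size below are made up for illustration):

import tensorflow as tf

# Toy stand-in model; layer sizes and batch size are arbitrary.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(3, activation="sigmoid"),
])
x = tf.random.normal([5, 6])  # batch of 5 examples, 6 features

with tf.GradientTape(persistent=True) as tape:
    y_pred = model(x, training=True)  # forward pass (also builds the model)
grads = tape.jacobian(
    y_pred,
    model.trainable_variables,
    parallel_iterations=1,
    experimental_use_pfor=False,
)
# grads[i] has shape [batch, num_classes, *weight_shape_i]

batch_size = tf.shape(grads[0])[0]
num_classes = grads[0].shape[1]
flat = [tf.reshape(g, [batch_size, num_classes, -1]) for g in grads]
all_grads = tf.concat(flat, axis=2)  # [batch, num_classes, all flattened weights]
two_norms = tf.map_fn(lambda m: tf.norm(m, axis=0), all_grads)
max_two_norm = tf.reduce_max(two_norms)  # feeds stddev = max_two_norm * noise_multiplier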
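The masking path can also be checked in isolation. The sketch below uses plain TensorFlow (no tf_shell) to mimic what happens around rebalance(): a gradient is hidden by a uniform additive mask, the masked value wraps modulo the plaintext modulus t (everything lives in the scaled domain, i.e. divided by the scaling factor s), and after subtracting the mask, rebalance() shifts any wrapped result back into [-t/2, t/2) so the original gradient is recovered. The values of t, s, and the gradient are invented for the example.

import tensorflow as tf

t = tf.constant(65537.0)  # illustrative plaintext modulus
s = tf.constant(8.0)      # illustrative scaling factor
t_half = t // 2
epsilon = tf.constant(1e-6)

def rebalance(x, s):
    # Shift values that wrapped around the modulus back into [-t/2, t/2),
    # working in the scaled domain (all quantities are divided by s).
    r_bound = t_half / s + epsilon
    l_bound = -t_half / s - epsilon
    t_over_s = t / s
    x = tf.where(x > r_bound, x - t_over_s, x)
    x = tf.where(x < l_bound, x + t_over_s, x)
    return x

grad = tf.constant([1.5, -2.25, 3.0])
mask = tf.random.uniform(tf.shape(grad), minval=-t_half / s, maxval=t_half / s)

# What the decrypting party sees: the masked gradient, reduced into the
# centered range mod t (scaled by s), mimicking SHELL's plaintext arithmetic.
masked = grad + mask
wrapped = tf.math.floormod(masked + t_half / s, t / s) - t_half / s

# The features party removes the mask and rebalances, recovering the gradient.
recovered = rebalance(wrapped - mask, s)
tf.debugging.assert_near(recovered, grad, atol=1e-2)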
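Finally, the constructors now take disable_encryption / disable_masking / disable_noise (plus check_overflow_close) in place of the old positional use_encryption flag, and the local tests exercise each switch independently. A hedged usage sketch patterned on postscale_model_local_test.py; the class name tf_shell_ml.PostScaleSequential, the import names, and everything outside the constructor call are assumptions, since this diff only shows constructor and test fragments.

import tensorflow as tf
import tf_shell
import tf_shell_ml  # assumed package/module names

def make_model(disable_encryption=False, disable_masking=False, disable_noise=False):
    # Assumed class name for the model defined in postscale_sequential_model.py.
    return tf_shell_ml.PostScaleSequential(
        [
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(10, activation="sigmoid"),
        ],
        lambda: tf_shell.create_autocontext64(
            log2_cleartext_sz=14,
            scaling_factor=2**10,
            noise_offset_log2=47,  # may be overprovisioned
        ),
        disable_encryption=disable_encryption,
        disable_masking=disable_masking,
        disable_noise=disable_noise,
    )

# The four configurations covered by the updated tests: the full secure path,
# then encryption, masking, and noise disabled one at a time.
configs = [
    (False, False, False),
    (True, False, False),
    (False, True, False),
    (False, False, True),
]
models = [make_model(*flags) for flags in configs]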