diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py
index b739ab3..de85155 100644
--- a/tf_shell_ml/dpsgd_sequential_model.py
+++ b/tf_shell_ml/dpsgd_sequential_model.py
@@ -33,7 +33,7 @@ def __init__(
         disable_encryption=False,
         disable_masking=False,
         disable_noise=False,
-        check_overflow=False,
+        check_overflow_INSECURE=False,
         cache_path=None,
         jacobian_pfor=False,
         jacobian_pfor_iterations=None,
@@ -62,7 +62,7 @@ def __init__(
         self.disable_encryption = disable_encryption
         self.disable_masking = disable_masking
         self.disable_noise = disable_noise
-        self.check_overflow = check_overflow
+        self.check_overflow_INSECURE = check_overflow_INSECURE
         self.cache_path = cache_path
         self.jacobian_pfor = jacobian_pfor
         self.jacobian_pfor_iterations = jacobian_pfor_iterations
@@ -99,7 +99,7 @@ def compute_max_two_norm_and_pred(self, features, skip_two_norm):
             grads = tape.jacobian(
                 y_pred,
                 self.trainable_variables,
-                unconnected_gradients='zero',
+                unconnected_gradients="zero",
                 parallel_iterations=self.jacobian_pfor_iterations,
                 experimental_use_pfor=self.jacobian_pfor,
             )
@@ -143,8 +143,8 @@ def train_step(self, data):

             # Backward pass.
             dx = self.loss_fn.grad(enc_y, y_pred)
-            dJ_dw = []
-            dJ_dx = [dx]
+            dJ_dw = []  # Derivatives of the loss with respect to the weights.
+            dJ_dx = [dx]  # Derivatives of the loss with respect to the inputs.
             for l in reversed(self.layers):
                 if isinstance(l, tf_shell_ml.GlobalAveragePooling1D):
                     dw, dx = l.backward(dJ_dx[-1])
@@ -153,6 +153,29 @@ def train_step(self, data):
                 dJ_dw.extend(dw)
                 dJ_dx.append(dx)

+        # Check if the backpropagated gradients overflowed.
+        if not self.disable_encryption and self.check_overflow_INSECURE:
+            # Note: checking the backprop gradients requires decryption on
+            # the features party, which breaks the security of the protocol.
+            bp_scaling_factors = [g._scaling_factor for g in dJ_dw]
+            dec_dJ_dw = [
+                tf_shell.to_tensorflow(
+                    g,
+                    (
+                        backprop_secret_fastrot_key
+                        if g._is_fast_rotated
+                        else backprop_secret_key
+                    ),
+                )
+                for g in dJ_dw
+            ]
+            self.warn_on_overflow(
+                dec_dJ_dw,
+                bp_scaling_factors,
+                backprop_context.plaintext_modulus,
+                "WARNING: Backprop gradient may have overflowed.",
+            )
+
         # Mask the encrypted grads to prepare for decryption. The masks may
         # overflow during the reduce_sum over the batch. When the masks are
         # operated on, they are multiplied by the scaling factor, so it is
@@ -214,9 +237,16 @@ def train_step(self, data):
             grads = [f(g) for f, g in zip(self.unpacking_funcs, packed_grads)]

         if not self.disable_noise:
-            # Efficiently pack the masked gradients to prepare for
-            # encryption. This is special because the masked gradients are
-            # no longer batched so the packing must be done manually.
+            tf.assert_equal(
+                backprop_context.num_slots,
+                noise_context.num_slots,
+                message="Backprop and noise contexts must have the same number of slots.",
+            )
+
+            # Efficiently pack the masked gradients to prepare for adding
+            # the encrypted noise. This is special because the masked
+            # gradients are no longer batched, so the packing must be done
+            # manually.
             (
                 flat_grads,
                 grad_shapes,
@@ -240,6 +270,16 @@ def train_step(self, data):
             # The gradients must be first be decrypted using the noise
             # secret key.
             flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)
+
+            if self.check_overflow_INSECURE:
+                noise_scaling_factors = grads._scaling_factor
+                self.warn_on_overflow(
+                    [flat_grads],
+                    [noise_scaling_factors],
+                    noise_context.plaintext_modulus,
+                    "WARNING: Noised gradient may have overflowed.",
+                )
+
             # Unpack the noise after decryption.
             grads = self.unflatten_and_unpad_grad(
                 flat_grads,
@@ -249,10 +289,12 @@ def train_step(self, data):
             )

         if not self.disable_masking and not self.disable_encryption:
-            # SHELL represents floats as integers between [0, t) where t is the
-            # plaintext modulus. To mimic the modulo operation without SHELL,
-            # numbers which exceed the range [-t/2, t/2) are shifted back into
-            # the range.
+            # SHELL represents floats as integers between [0, t) where t is
+            # the plaintext modulus. To mimic the modulo operation without
+            # SHELL, numbers which exceed the range [-t/2, t/2) are shifted
+            # back into the range. In this context, t is the plaintext
+            # modulus of the backprop context, since that is what the
+            # gradients were encrypted with when the mask was added.
             epsilon = tf.constant(1e-6, dtype=float)

             def rebalance(x, s):
@@ -274,22 +316,6 @@ def rebalance(x, s):
             grads = [mg - m for mg, m in zip(grads, unpacked_mask)]
             grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]

-        if not self.disable_encryption and self.check_overflow:
-            # If the unmasked gradient is between [-t/2, -t/4] or
-            # [t/4, t/2], the gradient may have overflowed. Note this must
-            # also take the scaling factor into account.
-            overflowed = [
-                tf.abs(g) > t_half / 2 / s
-                for g, s in zip(grads, mask_scaling_factors)
-            ]
-            overflowed = [tf.reduce_any(o) for o in overflowed]
-            overflowed = tf.reduce_any(overflowed)
-            tf.cond(
-                overflowed,
-                lambda: tf.print("WARNING: Gradient may have overflowed."),
-                lambda: tf.identity(overflowed),
-            )
-
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
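Reviewer note on the expanded comment above about SHELL's [0, t) representation: the sketch below restates the recentering idea in plain TensorFlow, using made-up values for the plaintext modulus t and scaling factor s. It is an illustration only, not the library's rebalance() implementation.

    import tensorflow as tf

    def recenter(x, t, s):
        # Decoded values live in [0, t) / s; anything above t / (2 * s) is a
        # negative number that wrapped around, so shift it back by t / s.
        half_range = t / (2.0 * s)
        x = tf.where(x > half_range, x - t / s, x)
        x = tf.where(x < -half_range, x + t / s, x)
        return x

    t = 2.0**14  # toy plaintext modulus
    s = 8.0      # toy scaling factor
    print(recenter(tf.constant([10.0, 2040.0, -5.0]), t, s))
    # 2040.0 exceeds t / (2 * s) = 1024, so it is shifted to 2040 - 2048 = -8.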
diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py
index 3add924..17cbc61 100644
--- a/tf_shell_ml/model_base.py
+++ b/tf_shell_ml/model_base.py
@@ -242,3 +242,30 @@ def unflatten_and_unpad_grad(
         # Reshape to the original shapes.
         grads_list = [tf.reshape(g, s) for g, s in zip(grads_list, grad_shapes)]
         return grads_list
+
+    def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
+        # If the gradient is between [-t/2, -t/4] or [t/4, t/2], the gradient
+        # may have overflowed. This must also take the scaling factor into
+        # account, so the range is divided by the scaling factor.
+        t = tf.cast(plaintext_modulus, grads[0].dtype)
+        t_half = t / 2
+
+        over_by = [
+            tf.reduce_max(tf.abs(g) - t_half / 2 / s)
+            for g, s in zip(grads, scaling_factors)
+        ]
+        max_over_by = tf.reduce_max(over_by)
+        overflowed = tf.reduce_any(max_over_by > 0)
+
+        tf.cond(
+            overflowed,
+            lambda: tf.print(
+                message,
+                "Overflowed by",
+                over_by,
+                "(positive number indicates overflow amount).",
+            ),
+            lambda: tf.identity(overflowed),
+        )
+
+        return overflowed
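Reviewer note: warn_on_overflow() flags a tensor once max(|g|) exceeds t_half / 2 / s, i.e. t / 4 / s. A toy check of that threshold with made-up numbers (no encryption, no real context involved):

    import tensorflow as tf

    t = 2.0**14                    # toy plaintext modulus
    s = 8.0                        # toy scaling factor
    threshold = t / 2.0 / 2.0 / s  # same t_half / 2 / s rule as above, here 512.0
    g = tf.constant([3.0, 600.0])  # 600 lands in the outer half of the signed range
    over_by = tf.reduce_max(tf.abs(g) - threshold)
    tf.print("overflow suspected:", over_by > 0, "over by", over_by)  # True, over by 88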
diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py
index 36056b6..0fc63e3 100644
--- a/tf_shell_ml/postscale_sequential_model.py
+++ b/tf_shell_ml/postscale_sequential_model.py
@@ -31,8 +31,10 @@ def __init__(
         disable_encryption=False,
         disable_masking=False,
         disable_noise=False,
-        check_overflow=False,
+        check_overflow_INSECURE=False,
         cache_path=None,
+        jacobian_pfor=False,
+        jacobian_pfor_iterations=None,
         *args,
         **kwargs,
     ):
@@ -59,8 +61,10 @@ def __init__(
         self.noise_multiplier = noise_multiplier
         self.disable_masking = disable_masking
         self.disable_noise = disable_noise
-        self.check_overflow = check_overflow
+        self.check_overflow_INSECURE = check_overflow_INSECURE
         self.cache_path = cache_path
+        self.jacobian_pfor = jacobian_pfor
+        self.jacobian_pfor_iterations = jacobian_pfor_iterations

     def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
         super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
@@ -101,8 +105,9 @@ def train_step(self, data):
             grads = tape.jacobian(
                 y_pred,
                 self.trainable_variables,
-                parallel_iterations=1,
-                experimental_use_pfor=False,
+                unconnected_gradients="zero",
+                parallel_iterations=self.jacobian_pfor_iterations,
+                experimental_use_pfor=self.jacobian_pfor,
             )
             # grads = tape.jacobian(y_pred, self.trainable_variables, experimental_use_pfor=True)
             # ^ layers list x (batch size x num output classes x weights) matrix
@@ -147,6 +152,22 @@ def train_step(self, data):
             )
             # ^ layers list (batch_size x weights)

+            # Check if the post-scaled gradients overflowed.
+            if not self.disable_encryption and self.check_overflow_INSECURE:
+                # Note: checking the backprop gradients requires decryption on
+                # the features party, which breaks the security of the protocol.
+                bp_scaling_factors = [g._scaling_factor for g in grads]
+                dec_grads = [
+                    tf_shell.to_tensorflow(g, backprop_secret_fastrot_key)
+                    for g in grads
+                ]
+                self.warn_on_overflow(
+                    dec_grads,
+                    bp_scaling_factors,
+                    backprop_context.plaintext_modulus,
+                    "WARNING: Backprop gradient may have overflowed.",
+                )
+
         # Mask the encrypted grads to prepare for decryption. The masks may
         # overflow during the reduce_sum over the batch. When the masks are
         # operated on, they are multiplied by the scaling factor, so it is
@@ -195,6 +216,12 @@ def train_step(self, data):
         ]

         if not self.disable_noise:
+            tf.assert_equal(
+                backprop_context.num_slots,
+                noise_context.num_slots,
+                message="Backprop and noise contexts must have the same number of slots.",
+            )
+
             # Efficiently pack the masked gradients to prepare for
             # encryption. This is special because the masked gradients are
             # no longer batched so the packing must be done manually.
@@ -217,7 +244,16 @@ def train_step(self, data):
         if not self.disable_noise:
             # The gradients must be first be decrypted using the noise
             # secret key.
-            grads = tf_shell.to_tensorflow(grads, noise_secret_key)
+            flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)
+
+            if not self.disable_encryption and self.check_overflow_INSECURE:
+                noise_scaling_factor = grads._scaling_factor
+                self.warn_on_overflow(
+                    [flat_grads],
+                    [noise_scaling_factor],
+                    noise_context.plaintext_modulus,
+                    "WARNING: Noised gradient may have overflowed.",
+                )
             # Unpack the noise after decryption.
             grads = self.unflatten_and_unpad_grad(
                 flat_grads,
@@ -252,22 +288,6 @@ def rebalance(x, s):

             grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]

-        if not self.disable_encryption and self.check_overflow:
-            # If the unmasked gradient is between [-t/2, -t/4] or
-            # [t/4, t/2], the gradient may have overflowed. Note this must
-            # also take the scaling factor into account.
-            overflowed = [
-                tf.abs(g) > t_half / 2 / s
-                for g, s in zip(grads, mask_scaling_factors)
-            ]
-            overflowed = [tf.reduce_any(o) for o in overflowed]
-            overflowed = tf.reduce_any(overflowed)
-            tf.cond(
-                overflowed,
-                lambda: tf.print("Gradient may have overflowed"),
-                lambda: tf.identity(overflowed),
-            )
-
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
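Reviewer note on the jacobian call above: the post-scale approach keeps the per-example Jacobian of the predictions and contracts it with dL/dy_pred afterwards. A minimal plain-TensorFlow sketch with toy shapes and no encryption (illustration only, not this model's actual code path):

    import tensorflow as tf

    batch, classes, rows, cols = 4, 3, 5, 3
    jac = tf.random.normal([batch, classes, rows, cols])  # d y_pred / d W, per example
    dl_dy = tf.random.normal([batch, classes])            # d loss / d y_pred, per example
    per_example_grads = tf.einsum("bcij,bc->bij", jac, dl_dy)
    print(per_example_grads.shape)  # (4, 5, 3): one weight gradient per example, before any reduction or clipping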
diff --git a/tf_shell_ml/test/dpsgd_conv_model_local_test.py b/tf_shell_ml/test/dpsgd_conv_model_local_test.py
index bd71071..6e303f9 100644
--- a/tf_shell_ml/test/dpsgd_conv_model_local_test.py
+++ b/tf_shell_ml/test/dpsgd_conv_model_local_test.py
@@ -80,24 +80,24 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
                 ),
             ],
             backprop_context_fn=lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=26,
-                scaling_factor=2,
-                noise_offset_log2=46,
+                log2_cleartext_sz=24,
+                scaling_factor=1,
+                noise_offset_log2=0,
                 cache_path=context_cache_path,
             ),
             noise_context_fn=lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=26,
-                scaling_factor=2,
-                noise_offset_log2=47,
+                scaling_factor=1,
+                noise_offset_log2=0,
                 cache_path=context_cache_path,
             ),
             disable_encryption=disable_encryption,
             disable_masking=disable_masking,
             disable_noise=disable_noise,
             cache_path=context_cache_path,
-            check_overflow=True,
-            jacobian_pfor=False,
-            jacobian_pfor_iterations=None,
+            # check_overflow_INSECURE=True,
+            # jacobian_pfor=True,
+            # jacobian_pfor_iterations=128,
         )

         m.compile(
@@ -110,7 +110,9 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
         m.build([None, 28, 28, 1])
         m.summary()

-        history = m.fit(train_dataset.take(9), epochs=1, validation_data=val_dataset)
+        history = m.fit(train_dataset.take(16), epochs=1, validation_data=val_dataset)
+
+        context_cache.cleanup()

         self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.30)

diff --git a/tf_shell_ml/test/dpsgd_model_distrib_test.py b/tf_shell_ml/test/dpsgd_model_distrib_test.py
index 5bc14d2..ab04ba9 100644
--- a/tf_shell_ml/test/dpsgd_model_distrib_test.py
+++ b/tf_shell_ml/test/dpsgd_model_distrib_test.py
@@ -97,13 +97,13 @@ def test_model(self):
                 ),
             ],
             lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=14,
-                scaling_factor=8,
-                noise_offset_log2=12,
+                log2_cleartext_sz=10,
+                scaling_factor=5,
+                noise_offset_log2=0,
                 cache_path=context_cache_path,
             ),
             lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=14,
+                log2_cleartext_sz=11,
                 scaling_factor=8,
                 noise_offset_log2=0,
                 cache_path=context_cache_path,
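Reviewer note on the smaller context parameters in the tests above: assuming the plaintext modulus t ends up close to 2**log2_cleartext_sz (an assumption about create_autocontext64, not something this diff states), the warning threshold from warn_on_overflow is t / 4 / s. For the conv test's backprop context:

    old_threshold = 2**26 / 4 / 2  # log2_cleartext_sz=26, scaling_factor=2 -> 8388608.0
    new_threshold = 2**24 / 4 / 1  # log2_cleartext_sz=24, scaling_factor=1 -> 4194304.0
    print(old_threshold, new_threshold)

Under that assumption, the new settings keep roughly half the headroom before the warning would fire, while dropping fractional precision entirely (scaling factor 1).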
diff --git a/tf_shell_ml/test/dpsgd_model_local_test.py b/tf_shell_ml/test/dpsgd_model_local_test.py
index 7616852..7bfe911 100644
--- a/tf_shell_ml/test/dpsgd_model_local_test.py
+++ b/tf_shell_ml/test/dpsgd_model_local_test.py
@@ -35,7 +35,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
         x_train, x_test = x_train[:, :512], x_test[:, :512]

         train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
-        train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(2**12)
+        train_dataset = train_dataset.shuffle(buffer_size=2**10).batch(2**10)

         val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
         val_dataset = val_dataset.batch(32)
@@ -62,20 +62,38 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
             ],
             lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=14,
-                scaling_factor=8,
-                noise_offset_log2=12,
+                scaling_factor=4,
+                noise_offset_log2=54,
+                # Smaller possible scaling factors below.
+                # log2_cleartext_sz=14,
+                # scaling_factor=2,
+                # noise_offset_log2=54,
+                # log2_cleartext_sz=12,
+                # scaling_factor=1,
+                # noise_offset_log2=31,
                 cache_path=context_cache_path,
             ),
             lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=14,
-                scaling_factor=8,
+                # log2_cleartext_sz=30,
+                # scaling_factor=32,
+                # noise_offset_log2=0,
+                log2_cleartext_sz=32,
+                scaling_factor=16,
                 noise_offset_log2=0,
+                # Smaller possible scaling factors below.
+                # log2_cleartext_sz=30,
+                # scaling_factor=8,
+                # noise_offset_log2=0,
+                # log2_cleartext_sz=30,
+                # scaling_factor=4,
+                # noise_offset_log2=0,
                 cache_path=context_cache_path,
             ),
-            disable_encryption=False,
-            disable_masking=False,
-            disable_noise=False,
+            disable_encryption=disable_encryption,
+            disable_masking=disable_masking,
+            disable_noise=disable_noise,
             cache_path=context_cache_path,
+            # check_overflow_INSECURE=True,
         )

         m.compile(
@@ -89,7 +107,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
         m.build([None, 512])
         m.summary()

-        history = m.fit(train_dataset.take(4), epochs=1, validation_data=val_dataset)
+        history = m.fit(train_dataset.take(2**5), epochs=1, validation_data=val_dataset)

         self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.30)

diff --git a/tf_shell_ml/test/postscale_model_local_test.py b/tf_shell_ml/test/postscale_model_local_test.py
index d30e2db..a9cf45a 100644
--- a/tf_shell_ml/test/postscale_model_local_test.py
+++ b/tf_shell_ml/test/postscale_model_local_test.py
@@ -50,19 +50,21 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
             ],
             lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=14,
-                scaling_factor=2**10,
+                scaling_factor=8,
+                # scaling_factor=2**10,
                 noise_offset_log2=47,  # may be overprovisioned
                 cache_path=context_cache_path,
             ),
             lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=14,
-                scaling_factor=2**10,
+                scaling_factor=8,
+                # scaling_factor=2**10,
                 noise_offset_log2=47,
                 cache_path=context_cache_path,
             ),
-            disable_encryption=False,
-            disable_masking=False,
-            disable_noise=False,
+            disable_encryption=disable_encryption,
+            disable_masking=disable_masking,
+            disable_noise=disable_noise,
             cache_path=context_cache_path,
         )
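Reviewer note on the dpsgd local test's backprop context and its commented-out alternatives: under the same assumption that t is roughly 2**log2_cleartext_sz, the decoded range available before warn_on_overflow fires is t / 4 / s.

    for log2_t, s in [(14, 8), (14, 4), (14, 2), (12, 1)]:
        print(log2_t, s, (2**log2_t) / 4 / s)
    # (14, 8) ->  512.0  old setting
    # (14, 4) -> 1024.0  new setting
    # (14, 2) -> 2048.0  commented alternative
    # (12, 1) -> 1024.0  commented alternative

Lowering the scaling factor widens the range that fits below the threshold, at the cost of fixed-point precision.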