From c99465b4c29cfa453503feddb7afae89ceae2e78 Mon Sep 17 00:00:00 2001
From: james-choncholas
Date: Fri, 1 Nov 2024 19:02:51 +0000
Subject: [PATCH] Ring formatter.

---
 tf_shell_ml/dpsgd_sequential_model.py     | 8 ++++++--
 tf_shell_ml/model_base.py                 | 4 +++-
 tf_shell_ml/postscale_sequential_model.py | 8 ++++++--
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py
index 3f9b280..135962d 100644
--- a/tf_shell_ml/dpsgd_sequential_model.py
+++ b/tf_shell_ml/dpsgd_sequential_model.py
@@ -46,7 +46,9 @@ def call(self, x, training=False):
         return x
 
     def compute_max_two_norm_and_pred(self, features, skip_two_norm):
-        with tf.GradientTape(persistent=tf.executing_eagerly() or self.jacobian_pfor) as tape:
+        with tf.GradientTape(
+            persistent=tf.executing_eagerly() or self.jacobian_pfor
+        ) as tape:
             y_pred = self(features, training=True)  # forward pass
 
             if not skip_two_norm:
@@ -272,4 +274,6 @@ def rebalance(x, s):
             else:
                 result[key] = value  # non-subdict elements are just copied
 
-        return result, None if self.disable_encryption else backprop_context.num_slots
+        return result, (
+            None if self.disable_encryption else backprop_context.num_slots
+        )
diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py
index 192e447..82beba8 100644
--- a/tf_shell_ml/model_base.py
+++ b/tf_shell_ml/model_base.py
@@ -55,7 +55,9 @@ def __init__(
         self.dataset_prepped = False
 
         if self.disable_encryption and self.jacobian_pfor:
-            print("WARNING: `jacobian_pfor` may be incompatible with `disable_encryption`.")
+            print(
+                "WARNING: `jacobian_pfor` may be incompatible with `disable_encryption`."
+            )
 
     def compile(self, shell_loss, **kwargs):
         if not isinstance(shell_loss, tf_shell_ml.CategoricalCrossentropy):
diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py
index dc3836a..c87a493 100644
--- a/tf_shell_ml/postscale_sequential_model.py
+++ b/tf_shell_ml/postscale_sequential_model.py
@@ -57,7 +57,9 @@ def shell_train_step(self, data):
         # is factored out of the gradient computation and accounted for below.
         self.layers[-1].activation = tf.keras.activations.linear
 
-        with tf.GradientTape(persistent=tf.executing_eagerly() or self.jacobian_pfor) as tape:
+        with tf.GradientTape(
+            persistent=tf.executing_eagerly() or self.jacobian_pfor
+        ) as tape:
             y_pred = self(x, training=True)  # forward pass
             grads = tape.jacobian(
                 y_pred,
@@ -273,4 +275,6 @@ def rebalance(x, s):
             else:
                 result[key] = value  # non-subdict elements are just copied
 
-        return result, None if self.disable_encryption else backprop_context.num_slots
+        return result, (
+            None if self.disable_encryption else backprop_context.num_slots
+        )
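
For reference, a minimal, self-contained sketch of the pattern these hunks reformat: a persistent `tf.GradientTape` whose persistence is conditioned on eager execution or pfor-based jacobians, followed by a per-example `tape.jacobian` call. This is illustrative only and not part of the patch; the plain Keras model and the `use_pfor` flag below are hypothetical stand-ins for the tf_shell_ml sequential models and `self.jacobian_pfor`.

import tensorflow as tf

# Hypothetical model standing in for the tf_shell_ml sequential models.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(3),  # linear output layer
    ]
)

x = tf.random.uniform((4, 8))  # batch of 4 examples, 8 features each
use_pfor = True  # stands in for self.jacobian_pfor

# Mirror the condition reformatted in the patch: keep the tape persistent
# under eager execution or when pfor-based jacobians are requested.
with tf.GradientTape(
    persistent=tf.executing_eagerly() or use_pfor
) as tape:
    y_pred = model(x, training=True)  # forward pass

# Per-example jacobian of the predictions w.r.t. each trainable variable.
jacobians = tape.jacobian(
    y_pred,
    model.trainable_variables,
    experimental_use_pfor=use_pfor,
)

for var, jac in zip(model.trainable_variables, jacobians):
    print(var.name, jac.shape)  # (batch, num_outputs, *var.shape)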