
Commit

Fix pfor when computing jacobian.
james-choncholas committed Nov 1, 2024
1 parent 86ae436 commit 8c761f5
Showing 4 changed files with 59 additions and 59 deletions.
56 changes: 28 additions & 28 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -46,14 +46,14 @@ def call(self, x, training=False):
         return x
 
     def compute_max_two_norm_and_pred(self, features, skip_two_norm):
-        with tf.GradientTape(persistent=tf.executing_eagerly()) as tape:
+        with tf.GradientTape(persistent=tf.executing_eagerly() or self.jacobian_pfor) as tape:
             y_pred = self(features, training=True)  # forward pass
 
             if not skip_two_norm:
                 grads = tape.jacobian(
                     y_pred,
                     self.trainable_variables,
-                    unconnected_gradients="zero",
+                    # unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
                     parallel_iterations=self.jacobian_pfor_iterations,
                     experimental_use_pfor=self.jacobian_pfor,
                 )
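
The `persistent=... or self.jacobian_pfor` change is the heart of the fix: the pfor path of `tape.jacobian` may differentiate the same tape once per output element, which a non-persistent tape refuses. A minimal standalone sketch of the pattern with toy shapes (not code from this repository):

    import tensorflow as tf

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    w = tf.Variable(tf.ones([2, 3]))
    use_pfor = True  # stands in for self.jacobian_pfor

    # Keep the tape persistent whenever pfor is enabled, since the pfor
    # path may differentiate the tape once per element of y.
    with tf.GradientTape(persistent=tf.executing_eagerly() or use_pfor) as tape:
        y = tf.matmul(x, w)  # forward pass, shape (2, 3)

    jac = tape.jacobian(y, w, experimental_use_pfor=use_pfor)
    print(jac.shape)  # (2, 3, 2, 3): one gradient of w per element of y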
@@ -244,32 +244,32 @@ def rebalance(x, s):
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
 
         for metric in self.metrics:
             if metric.name == "loss":
                 if self.disable_encryption:
                     loss = self.loss_fn(y, y_pred)
                     metric.update_state(loss)
                 else:
                     # Loss is unknown when encrypted.
                     metric.update_state(0.0)
             else:
                 if self.disable_encryption:
                     metric.update_state(y, y_pred)
                 else:
                     # Other metrics are unknown when encrypted.
                     zeros = tf.broadcast_to(0, tf.shape(y_pred))
                     metric.update_state(zeros, zeros)
 
         metric_results = {m.name: m.result() for m in self.metrics}
 
         # TensorFlow 2.18.0 added a "CompiledMetrics" metric which holds metrics
         # passed to compile in its own dictionary. Keras wants all metrics to
         # be returned as a flat dictionary. Here we flatten the dictionary.
         result = {}
         for key, value in metric_results.items():
             if isinstance(value, dict):
                 result.update(value)  # add subdict directly into the dict
             else:
                 result[key] = value  # non-subdict elements are just copied
 
         return result, None if self.disable_encryption else backprop_context.num_slots
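
The flattening loop above exists because, starting with TensorFlow 2.18.0, `metric_results` can hold a nested dictionary of compiled metrics. A small illustration of the same transformation; the key names and numbers below are made up for the example:

    # Hypothetical shape of metric_results under TF 2.18.0's CompiledMetrics.
    metric_results = {
        "loss": 0.42,
        "compile_metrics": {"accuracy": 0.91, "precision": 0.88},
    }

    result = {}
    for key, value in metric_results.items():
        if isinstance(value, dict):
            result.update(value)  # hoist nested metrics to the top level
        else:
            result[key] = value

    print(result)  # {'loss': 0.42, 'accuracy': 0.91, 'precision': 0.88}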
3 changes: 3 additions & 0 deletions tf_shell_ml/model_base.py
@@ -54,6 +54,9 @@ def __init__(
         self.check_overflow_INSECURE = check_overflow_INSECURE
         self.dataset_prepped = False
 
+        if self.disable_encryption and self.jacobian_pfor:
+            print("WARNING: `jacobian_pfor` may be incompatible with `disable_encryption`.")
+
     def compile(self, shell_loss, **kwargs):
         if not isinstance(shell_loss, tf_shell_ml.CategoricalCrossentropy):
             raise ValueError(
57 changes: 28 additions & 29 deletions tf_shell_ml/postscale_sequential_model.py
@@ -57,16 +57,15 @@ def shell_train_step(self, data):
         # is factored out of the gradient computation and accounted for below.
         self.layers[-1].activation = tf.keras.activations.linear
 
-        with tf.GradientTape(persistent=tf.executing_eagerly()) as tape:
+        with tf.GradientTape(persistent=tf.executing_eagerly() or self.jacobian_pfor) as tape:
             y_pred = self(x, training=True)  # forward pass
             grads = tape.jacobian(
                 y_pred,
                 self.trainable_variables,
-                unconnected_gradients="zero",
+                # unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
                 parallel_iterations=self.jacobian_pfor_iterations,
                 experimental_use_pfor=self.jacobian_pfor,
             )
-            # grads = tape.jacobian(y_pred, self.trainable_variables, experimental_use_pfor=True)
             # ^ layers list x (batch size x num output classes x weights) matrix
             # dy_pred_j/dW_sample_class
 
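Dropping `unconnected_gradients` means any trainable variable that `y_pred` never touches now shows up in `grads` as `None` rather than a zero Jacobian. If zero-filling is still wanted, one assumed workaround is to patch the `None` entries by hand; the helper below is a sketch under that assumption, not code from this commit:

    import tensorflow as tf

    def zero_fill_unconnected(grads, variables, batch_size, num_outputs):
        # Emulates tf.UnconnectedGradients.ZERO for per-example Jacobians.
        # Assumes connected entries have shape [batch, outputs, *var.shape],
        # matching the layers-list layout described in the comments above.
        filled = []
        for g, v in zip(grads, variables):
            if g is None:
                filled.append(tf.zeros([batch_size, num_outputs] + v.shape.as_list()))
            else:
                filled.append(g)
        return filled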
@@ -246,32 +245,32 @@ def rebalance(x, s):
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
 
         for metric in self.metrics:
             if metric.name == "loss":
                 if self.disable_encryption:
                     loss = self.loss_fn(y, y_pred)
                     metric.update_state(loss)
                 else:
                     # Loss is unknown when encrypted.
                     metric.update_state(0.0)
             else:
                 if self.disable_encryption:
                     metric.update_state(y, y_pred)
                 else:
                     # Other metrics are unknown when encrypted.
                     zeros = tf.broadcast_to(0, tf.shape(y_pred))
                     metric.update_state(zeros, zeros)
 
         metric_results = {m.name: m.result() for m in self.metrics}
 
         # TensorFlow 2.18.0 added a "CompiledMetrics" metric which holds metrics
         # passed to compile in its own dictionary. Keras wants all metrics to
         # be returned as a flat dictionary. Here we flatten the dictionary.
         result = {}
         for key, value in metric_results.items():
             if isinstance(value, dict):
                 result.update(value)  # add subdict directly into the dict
             else:
                 result[key] = value  # non-subdict elements are just copied
 
         return result, None if self.disable_encryption else backprop_context.num_slots
2 changes: 0 additions & 2 deletions tf_shell_ml/test/dpsgd_conv_model_local_test.py
@@ -94,8 +94,6 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
             disable_noise=disable_noise,
             cache_path=cache,
             # check_overflow_INSECURE=True,
-            # jacobian_pfor=True,
-            # jacobian_pfor_iterations=128,
         )
 
         m.compile(
