Models should not compute loss when encrypted backprop enabled.
james-choncholas committed Oct 30, 2024
1 parent da280db commit 44ece50
Showing 3 changed files with 15 additions and 17 deletions.
tf_shell_ml/dpsgd_sequential_model.py (5 additions, 5 deletions)
@@ -247,18 +247,18 @@ def rebalance(x, s):
         for metric in self.metrics:
             if metric.name == "loss":
                 if self.disable_encryption:
-                    metric.update_state(0)  # Loss is not known when encrypted.
-                else:
                     loss = self.loss_fn(y, y_pred)
                     metric.update_state(loss)
+                else:
+                    # Loss is unknown when encrypted.
+                    metric.update_state(0.0)
             else:
                 if self.disable_encryption:
                     metric.update_state(y, y_pred)
                 else:
+                    # Other metrics are unknown when encrypted.
                     zeros = tf.broadcast_to(0, tf.shape(y_pred))
-                    metric.update_state(
-                        zeros, zeros
-                    )  # No other metrics when encrypted.
+                    metric.update_state(zeros, zeros)

         metric_results = {m.name: m.result() for m in self.metrics}

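The fix swaps the two branches: before this commit, the loss was zeroed when
encryption was disabled and computed when it was enabled, which is backwards,
since the loss cannot be evaluated on encrypted predictions. Below is a
minimal, self-contained sketch of the corrected behavior, using plain Keras
metrics and a boolean flag standing in for the model's disable_encryption
attribute (the loss function and tensors are illustrative, not from this
repository):

import tensorflow as tf

# Stand-ins for the model's state (illustrative only, not from this repo).
disable_encryption = False  # i.e. encrypted backprop is enabled
loss_fn = tf.keras.losses.CategoricalCrossentropy()
metrics = [
    tf.keras.metrics.Mean(name="loss"),
    tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
]

y = tf.constant([[0.0, 1.0], [1.0, 0.0]])       # one-hot labels
y_pred = tf.constant([[0.1, 0.9], [0.8, 0.2]])  # plaintext stand-in

for metric in metrics:
    if metric.name == "loss":
        if disable_encryption:
            metric.update_state(loss_fn(y, y_pred))
        else:
            # Loss is unknown when encrypted; record a placeholder.
            metric.update_state(0.0)
    else:
        if disable_encryption:
            metric.update_state(y, y_pred)
        else:
            # Other metrics are unknown when encrypted; feed zeros.
            zeros = tf.broadcast_to(0.0, tf.shape(y_pred))
            metric.update_state(zeros, zeros)

print({m.name: float(m.result()) for m in metrics})

With disable_encryption = False, both metrics receive placeholders, so the
reported values carry no information about the encrypted training step.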
tf_shell_ml/loss.py (5 additions, 7 deletions)
@@ -16,7 +16,7 @@
 from tensorflow.nn import softmax
 from tensorflow.nn import sigmoid
 from tensorflow.math import log
-import tf_shell
+import tensorflow as tf


 class CategoricalCrossentropy:
@@ -27,10 +27,9 @@ def __init__(self, from_logits=False, lazy_normalization=True):
     def __call__(self, y_true, y_pred):
         if self.from_logits:
             y_pred = softmax(y_pred)
-        batch_size = y_true.shape.as_list()[0]
-        batch_size_inv = 1 / batch_size
+        batch_size = tf.shape(y_true)[0]
         out = -y_true * log(y_pred)
-        cce = tf_shell.reduce_sum(out, axis=0) * batch_size_inv
+        cce = tf.reduce_sum(out, axis=0) / tf.cast(batch_size, out.dtype)
         return cce

     def grad(self, y_true, y_pred):
@@ -61,10 +60,9 @@ def __call__(self, y_true, y_pred):
         if self.from_logits:
             y_pred = sigmoid(y_pred)

-        batch_size = y_true.shape.as_list()[0]
-        batch_size_inv = 1 / batch_size
+        batch_size = tf.shape(y_true)[0]
         out = -(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))
-        bce = tf_shell.reduce_sum(out, axis=0) * batch_size_inv
+        bce = tf.reduce_sum(out, axis=0) / tf.cast(batch_size, out.dtype)
         return bce

     def grad(self, y_true, y_pred):
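
Two things change in this file: the batch size is now read dynamically with
tf.shape() rather than from the static tensor shape, and the plaintext
reduction uses tf.reduce_sum in place of tf_shell.reduce_sum. The static form
y_true.shape.as_list()[0] returns None when the batch dimension is unknown at
trace time, so 1 / batch_size raises a TypeError inside a tf.function traced
with an unspecified batch dimension; tf.shape() defers the lookup to run time.
A standalone sketch of the updated CategoricalCrossentropy.__call__ arithmetic
(the wrapper function and test tensors are illustrative):

import tensorflow as tf
from tensorflow.nn import softmax
from tensorflow.math import log

# Standalone copy of the updated __call__ arithmetic (from_logits assumed True).
def cce_from_logits(y_true, y_pred):
    y_pred = softmax(y_pred)
    batch_size = tf.shape(y_true)[0]  # dynamic: resolved when the op runs
    out = -y_true * log(y_pred)
    # Sum over the batch per class, then divide by the batch size.
    return tf.reduce_sum(out, axis=0) / tf.cast(batch_size, out.dtype)

# Tracing with an unknown (None) batch dimension now works; the old
# y_true.shape.as_list()[0] would be None here, and 1 / None raises.
spec = tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
compiled = tf.function(cce_from_logits, input_signature=[spec, spec])

y_true = tf.constant([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
logits = tf.constant([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
print(compiled(y_true, logits))  # shape (3,): batch-averaged loss per class

When the batch size is known statically, dividing by
tf.cast(batch_size, out.dtype) matches the old multiplication by
1 / batch_size up to floating-point rounding, so the refactor changes when the
batch size is resolved, not the result.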
tf_shell_ml/postscale_sequential_model.py (5 additions, 5 deletions)
@@ -249,18 +249,18 @@ def rebalance(x, s):
         for metric in self.metrics:
             if metric.name == "loss":
                 if self.disable_encryption:
-                    metric.update_state(0)  # Loss is not known when encrypted.
-                else:
                     loss = self.loss_fn(y, y_pred)
                     metric.update_state(loss)
+                else:
+                    # Loss is unknown when encrypted.
+                    metric.update_state(0.0)
             else:
                 if self.disable_encryption:
                     metric.update_state(y, y_pred)
                 else:
+                    # Other metrics are unknown when encrypted.
                     zeros = tf.broadcast_to(0, tf.shape(y_pred))
-                    metric.update_state(
-                        zeros, zeros
-                    )  # No other metrics when encrypted.
+                    metric.update_state(zeros, zeros)

         metric_results = {m.name: m.result() for m in self.metrics}

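This hunk is identical to the one in dpsgd_sequential_model.py, so both
sequential models now report placeholder metrics whenever encryption is
enabled. A short self-contained sketch of what the metric_results dict then
contains (hypothetical metric set; the values are placeholders, not
measurements):

import tensorflow as tf

# Hypothetical metric set mirroring the loop above.
metrics = [
    tf.keras.metrics.Mean(name="loss"),
    tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
]
y_pred = tf.zeros([2, 2])  # stand-in for the model output

metrics[0].update_state(0.0)  # loss placeholder: unknown under encryption
zeros = tf.broadcast_to(0.0, tf.shape(y_pred))
metrics[1].update_state(zeros, zeros)  # other metrics zeroed out

metric_results = {m.name: m.result() for m in metrics}
print(metric_results)  # placeholder values; not meaningful under encryption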
