diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py index dbe4949..cbe9069 100644 --- a/tf_shell_ml/dpsgd_sequential_model.py +++ b/tf_shell_ml/dpsgd_sequential_model.py @@ -247,18 +247,18 @@ def rebalance(x, s): for metric in self.metrics: if metric.name == "loss": if self.disable_encryption: - metric.update_state(0) # Loss is not known when encrypted. - else: loss = self.loss_fn(y, y_pred) metric.update_state(loss) + else: + # Loss is unknown when encrypted. + metric.update_state(0.0) else: if self.disable_encryption: metric.update_state(y, y_pred) else: + # Other metrics are unknown when encrypted. zeros = tf.broadcast_to(0, tf.shape(y_pred)) - metric.update_state( - zeros, zeros - ) # No other metrics when encrypted. + metric.update_state(zeros, zeros) metric_results = {m.name: m.result() for m in self.metrics} diff --git a/tf_shell_ml/loss.py b/tf_shell_ml/loss.py index ef1cb92..d52a45d 100644 --- a/tf_shell_ml/loss.py +++ b/tf_shell_ml/loss.py @@ -16,7 +16,7 @@ from tensorflow.nn import softmax from tensorflow.nn import sigmoid from tensorflow.math import log -import tf_shell +import tensorflow as tf class CategoricalCrossentropy: @@ -27,10 +27,9 @@ def __init__(self, from_logits=False, lazy_normalization=True): def __call__(self, y_true, y_pred): if self.from_logits: y_pred = softmax(y_pred) - batch_size = y_true.shape.as_list()[0] - batch_size_inv = 1 / batch_size + batch_size = tf.shape(y_true)[0] out = -y_true * log(y_pred) - cce = tf_shell.reduce_sum(out, axis=0) * batch_size_inv + cce = tf.reduce_sum(out, axis=0) / tf.cast(batch_size, out.dtype) return cce def grad(self, y_true, y_pred): @@ -61,10 +60,9 @@ def __call__(self, y_true, y_pred): if self.from_logits: y_pred = sigmoid(y_pred) - batch_size = y_true.shape.as_list()[0] - batch_size_inv = 1 / batch_size + batch_size = tf.shape(y_true)[0] out = -(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred)) - bce = tf_shell.reduce_sum(out, axis=0) * batch_size_inv + 
bce = tf.reduce_sum(out, axis=0) / tf.cast(batch_size, out.dtype) return bce def grad(self, y_true, y_pred): diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py index a861461..b9b9344 100644 --- a/tf_shell_ml/postscale_sequential_model.py +++ b/tf_shell_ml/postscale_sequential_model.py @@ -249,18 +249,18 @@ def rebalance(x, s): for metric in self.metrics: if metric.name == "loss": if self.disable_encryption: - metric.update_state(0) # Loss is not known when encrypted. - else: loss = self.loss_fn(y, y_pred) metric.update_state(loss) + else: + # Loss is unknown when encrypted. + metric.update_state(0.0) else: if self.disable_encryption: metric.update_state(y, y_pred) else: + # Other metrics are unknown when encrypted. zeros = tf.broadcast_to(0, tf.shape(y_pred)) - metric.update_state( - zeros, zeros - ) # No other metrics when encrypted. + metric.update_state(zeros, zeros) metric_results = {m.name: m.result() for m in self.metrics}