diff --git a/tf_shell_ml/test/dpsgd_conv_model_local_test.py b/tf_shell_ml/test/dpsgd_conv_model_local_test.py index 1decd62..3c9ba9a 100644 --- a/tf_shell_ml/test/dpsgd_conv_model_local_test.py +++ b/tf_shell_ml/test/dpsgd_conv_model_local_test.py @@ -69,7 +69,8 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) tf_shell_ml.Flatten(), tf_shell_ml.ShellDense( 16, - activation=tf.nn.softmax, + activation=tf_shell_ml.relu, + activation_deriv=tf_shell_ml.relu_deriv, ), tf_shell_ml.ShellDense( 10,