diff --git a/tf_shell_ml/test/dpsgd_conv_model_local_test.py b/tf_shell_ml/test/dpsgd_conv_model_local_test.py index 6438a34..5df0221 100644 --- a/tf_shell_ml/test/dpsgd_conv_model_local_test.py +++ b/tf_shell_ml/test/dpsgd_conv_model_local_test.py @@ -32,6 +32,11 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0) y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10) + # Clip the input images to make testing faster. + clip_by = 8 + x_train = x_train[:, clip_by : (28 - clip_by), clip_by : (28 - clip_by), :] + x_test = x_test[:, clip_by : (28 - clip_by), clip_by : (28 - clip_by), :] + labels_dataset = tf.data.Dataset.from_tensor_slices(y_train) labels_dataset = labels_dataset.batch(2**12) @@ -44,7 +49,8 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) m = tf_shell_ml.DpSgdSequential( [ # Model from tensorflow-privacy tutorial. The first layer may - # be skipped and the model still has ~95% accuracy (plaintext). + # be skipped and the model still has ~95% accuracy (plaintext, + # no input clipping). # tf_shell_ml.Conv2D( # filters=16, # kernel_size=8, @@ -61,7 +67,6 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) filters=32, kernel_size=4, strides=2, - padding="valid", activation=tf_shell_ml.relu, activation_deriv=tf_shell_ml.relu_deriv, ), @@ -101,23 +106,23 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache) m.compile( shell_loss=tf_shell_ml.CategoricalCrossentropy(), - optimizer=tf.keras.optimizers.Adam(0.1), + optimizer=tf.keras.optimizers.Adam(0.18), metrics=[tf.keras.metrics.CategoricalAccuracy()], ) - m.build([None, 28, 28, 1]) + m.build([None, 28 - (2 * clip_by), 28 - (2 * clip_by), 1]) m.summary() history = m.fit( features_dataset, labels_dataset, - steps_per_epoch=8, + steps_per_epoch=1, epochs=1, - verbose=2, + verbose=1, validation_data=val_dataset, ) - self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25) + self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.13) def test_model(self): with tempfile.TemporaryDirectory() as cache_dir: