Skip to content

Commit

Permalink
Shrink dpsgd convolutional test.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Nov 4, 2024
1 parent ba7fff2 commit f86add2
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions tf_shell_ml/test/dpsgd_conv_model_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)
y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)

# Clip the input images to make testing faster.
clip_by = 8
x_train = x_train[:, clip_by : (28 - clip_by), clip_by : (28 - clip_by), :]
x_test = x_test[:, clip_by : (28 - clip_by), clip_by : (28 - clip_by), :]

labels_dataset = tf.data.Dataset.from_tensor_slices(y_train)
labels_dataset = labels_dataset.batch(2**12)

Expand All @@ -44,7 +49,8 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
m = tf_shell_ml.DpSgdSequential(
[
# Model from tensorflow-privacy tutorial. The first layer may
# be skipped and the model still has ~95% accuracy (plaintext).
# be skipped and the model still has ~95% accuracy (plaintext,
# no input clipping).
# tf_shell_ml.Conv2D(
# filters=16,
# kernel_size=8,
Expand All @@ -61,7 +67,6 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
filters=32,
kernel_size=4,
strides=2,
padding="valid",
activation=tf_shell_ml.relu,
activation_deriv=tf_shell_ml.relu_deriv,
),
Expand Down Expand Up @@ -101,23 +106,23 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)

m.compile(
shell_loss=tf_shell_ml.CategoricalCrossentropy(),
optimizer=tf.keras.optimizers.Adam(0.1),
optimizer=tf.keras.optimizers.Adam(0.18),
metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

m.build([None, 28, 28, 1])
m.build([None, 28 - (2 * clip_by), 28 - (2 * clip_by), 1])
m.summary()

history = m.fit(
features_dataset,
labels_dataset,
steps_per_epoch=8,
steps_per_epoch=1,
epochs=1,
verbose=2,
verbose=1,
validation_data=val_dataset,
)

self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25)
self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.13)

def test_model(self):
with tempfile.TemporaryDirectory() as cache_dir:
Expand Down

0 comments on commit f86add2

Please sign in to comment.