Skip to content

Commit

Permalink
Sequential models cache context, keys, and check for overflow.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Oct 11, 2024
1 parent 6105513 commit 3822693
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 35 deletions.
33 changes: 23 additions & 10 deletions tf_shell_ml/dpsgd_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import tensorflow.keras as keras
import tf_shell
import tf_shell_ml
import tempfile
from tf_shell_ml.model_base import SequentialBase


Expand All @@ -33,7 +32,8 @@ def __init__(
disable_encryption=False,
disable_masking=False,
disable_noise=False,
check_overflow_close=True,
check_overflow=False,
cache_path=None,
*args,
**kwargs,
):
Expand All @@ -58,7 +58,8 @@ def __init__(
self.disable_encryption = disable_encryption
self.disable_masking = disable_masking
self.disable_noise = disable_noise
self.check_overflow_close = check_overflow_close
self.check_overflow = check_overflow
self.cache_path = cache_path

def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
Expand Down Expand Up @@ -115,18 +116,14 @@ def train_step(self, data):
if self.disable_encryption:
enc_y = y
else:
key_path = tempfile.mkdtemp() # Every trace gets a new key.

shell_context = self.shell_context_fn()
secret_key = tf_shell.create_key64(
shell_context, key_path + "/secret_key"
)
secret_key = tf_shell.create_key64(shell_context, self.cache_path)
secret_fast_rotation_key = tf_shell.create_fast_rotation_key64(
shell_context, secret_key, key_path + "/secret_fast_rotation_key"
shell_context, secret_key, self.cache_path
)
if self.needs_public_rotation_key:
public_rotation_key = tf_shell.create_rotation_key64(
shell_context, secret_key, key_path + "/public_rotation_key"
shell_context, secret_key, self.cache_path
)
# Encrypt the batch of secret labels y.
enc_y = tf_shell.to_encrypted(y, secret_key, shell_context)
Expand Down Expand Up @@ -238,6 +235,22 @@ def rebalance(x, s):
grads = [mg - m for mg, m in zip(masked_grads, unpacked_mask)]
grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]

if not self.disable_encryption and self.check_overflow:
# If the unmasked gradient is between [-t/2, -t/4] or
# [t/4, t/2], the gradient may have overflowed. Note this must
# also take the scaling factor into account.
overflowed = [
tf.abs(g) > t_half / 2 / s
for g, s in zip(batch_grad, mask_scaling_factors)
]
overflowed = [tf.reduce_any(o) for o in overflowed]
overflowed = tf.reduce_any(overflowed)
tf.cond(
overflowed,
lambda: tf.print("Gradient may have overflowed"),
lambda: tf.identity(overflowed),
)

# TODO: set stddev based on clipping threshold.
if self.disable_noise:
noised_grads = grads
Expand Down
22 changes: 12 additions & 10 deletions tf_shell_ml/postscale_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import tensorflow as tf
import tensorflow.keras as keras
import tf_shell
import tempfile
from tf_shell_ml.model_base import SequentialBase


Expand All @@ -31,7 +30,8 @@ def __init__(
disable_encryption=False,
disable_masking=False,
disable_noise=False,
check_overflow_close=True,
check_overflow=False,
cache_path=None,
*args,
**kwargs,
):
Expand All @@ -57,7 +57,8 @@ def __init__(
self.noise_multiplier = noise_multiplier
self.disable_masking = disable_masking
self.disable_noise = disable_noise
self.check_overflow_close = check_overflow_close
self.check_overflow = check_overflow
self.cache_path = cache_path

def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
Expand All @@ -82,11 +83,8 @@ def train_step(self, data):
if self.disable_encryption:
enc_y = y
else:
key_path = tempfile.mkdtemp() # Every trace gets a new key.
shell_context = self.shell_context_fn()
secret_key = tf_shell.create_key64(
shell_context, key_path + "/secret_key"
)
secret_key = tf_shell.create_key64(shell_context, self.cache_path)
# Encrypt the batch of secret labels y.
enc_y = tf_shell.to_encrypted(y, secret_key, shell_context)

Expand Down Expand Up @@ -243,10 +241,14 @@ def rebalance(x, s):
rebalance(g, s) for g, s in zip(batch_grad, mask_scaling_factors)
]

if not self.disable_encryption and self.check_overflow_close:
if not self.disable_encryption and self.check_overflow:
# If the unmasked gradient is between [-t/2, -t/4] or
# [t/4, t/2], the gradient may have overflowed.
overflowed = [tf.abs(g) > t_half / 2 for g in batch_grad]
# [t/4, t/2], the gradient may have overflowed. Note this must
# also take the scaling factor into account.
overflowed = [
tf.abs(g) > t_half / 2 / s
for g, s in zip(batch_grad, mask_scaling_factors)
]
overflowed = [tf.reduce_any(o) for o in overflowed]
overflowed = tf.reduce_any(overflowed)
tf.cond(
Expand Down
6 changes: 3 additions & 3 deletions tf_shell_ml/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ py_test(

py_test(
name = "dpsgd_model_local_test",
size = "large",
size = "enormous",
srcs = ["dpsgd_model_local_test.py"],
tags = ["exclusive"],
deps = [
Expand All @@ -65,7 +65,7 @@ py_test(

py_test(
name = "dpsgd_model_distrib_test",
size = "large",
size = "enormous",
srcs = ["dpsgd_model_distrib_test.py"],
tags = ["exclusive"],
deps = [
Expand All @@ -76,7 +76,7 @@ py_test(

py_test(
name = "postscale_model_local_test",
size = "large",
size = "enormous",
srcs = ["postscale_model_local_test.py"],
tags = ["exclusive"],
deps = [
Expand Down
19 changes: 9 additions & 10 deletions tf_shell_ml/test/dpsgd_model_distrib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,17 @@ def test_model(self):

# Clip dataset images to limit memory usage. The model accuracy will be
# bad but this test only measures functionality.
x_train, x_test = x_train[:, :120], x_test[:, :120]
x_train, x_test = x_train[:, :380], x_test[:, :380]

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=2**14).batch(2**12)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)

context_cache_path = "/tmp/postscale_model_distrib_test_cache/"
os.makedirs(context_cache_path, exist_ok=True)

m = tf_shell_ml.DpSgdSequential(
[
tf_shell_ml.ShellDense(
Expand All @@ -93,19 +96,15 @@ def test_model(self):
use_fast_reduce_sum=True,
),
],
# lambda: tf_shell.create_context64(
# log_n=12,
# main_moduli=[288230376151760897, 288230376152137729],
# plaintext_modulus=4294991873,
# scaling_factor=3,
# ),
lambda: tf_shell.create_autocontext64(
log2_cleartext_sz=32,
scaling_factor=3,
noise_offset_log2=25,
log2_cleartext_sz=14,
scaling_factor=8,
noise_offset_log2=12,
cache_path=context_cache_path,
),
labels_party_dev=labels_party_dev,
features_party_dev=features_party_dev,
cache_path=context_cache_path,
)

m.compile(
Expand Down
8 changes: 7 additions & 1 deletion tf_shell_ml/test/dpsgd_model_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import tf_shell
import tf_shell_ml
import os


class TestModel(tf.test.TestCase):
Expand All @@ -39,6 +40,9 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)

context_cache_path = "/tmp/postscale_model_local_test_cache/"
os.makedirs(context_cache_path, exist_ok=True)

# Turn on the shell optimizer to use autocontext.
tf_shell.enable_optimization()

Expand All @@ -60,10 +64,12 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
log2_cleartext_sz=14,
scaling_factor=8,
noise_offset_log2=12,
cache_path=context_cache_path,
),
disable_encryption=False,
disable_masking=False,
disable_noise=False,
cache_path=context_cache_path,
)

m.compile(
Expand All @@ -79,7 +85,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):

history = m.fit(train_dataset.take(4), epochs=1, validation_data=val_dataset)

self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3)
self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.35)

def test_model(self):
self._test_model(False, False, False)
Expand Down
8 changes: 7 additions & 1 deletion tf_shell_ml/test/postscale_model_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import tf_shell
import tf_shell_ml
import os


class TestModel(tf.test.TestCase):
Expand All @@ -39,6 +40,9 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)

context_cache_path = "/tmp/postscale_model_local_test_cache/"
os.makedirs(context_cache_path, exist_ok=True)

m = tf_shell_ml.PostScaleSequential(
[
tf.keras.layers.Dense(64, activation="relu"),
Expand All @@ -48,10 +52,12 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
log2_cleartext_sz=14,
scaling_factor=2**10,
noise_offset_log2=47, # may be overprovisioned
cache_path=context_cache_path,
),
disable_encryption=False,
disable_masking=False,
disable_noise=False,
cache_path=context_cache_path,
)

m.compile(
Expand All @@ -63,7 +69,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):

history = m.fit(train_dataset.take(4), epochs=1, validation_data=val_dataset)

self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3)
self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25)

def test_model(self):
self._test_model(False, False, False)
Expand Down

0 comments on commit 3822693

Please sign in to comment.