Models support checking for overflow.
james-choncholas committed Oct 24, 2024
1 parent f2bc5a0 commit d431069
Showing 7 changed files with 171 additions and 76 deletions.
82 changes: 54 additions & 28 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -33,7 +33,7 @@ def __init__(
disable_encryption=False,
disable_masking=False,
disable_noise=False,
check_overflow=False,
check_overflow_INSECURE=False,
cache_path=None,
jacobian_pfor=False,
jacobian_pfor_iterations=None,
@@ -62,7 +62,7 @@ def __init__(
self.disable_encryption = disable_encryption
self.disable_masking = disable_masking
self.disable_noise = disable_noise
self.check_overflow = check_overflow
self.check_overflow_INSECURE = check_overflow_INSECURE
self.cache_path = cache_path
self.jacobian_pfor = jacobian_pfor
self.jacobian_pfor_iterations = jacobian_pfor_iterations
@@ -99,7 +99,7 @@ def compute_max_two_norm_and_pred(self, features, skip_two_norm):
grads = tape.jacobian(
y_pred,
self.trainable_variables,
unconnected_gradients='zero',
unconnected_gradients="zero",
parallel_iterations=self.jacobian_pfor_iterations,
experimental_use_pfor=self.jacobian_pfor,
)
@@ -143,8 +143,8 @@ def train_step(self, data):

# Backward pass.
dx = self.loss_fn.grad(enc_y, y_pred)
dJ_dw = []
dJ_dx = [dx]
dJ_dw = [] # Derivatives of the loss with respect to the weights.
dJ_dx = [dx] # Derivatives of the loss with respect to the inputs.
for l in reversed(self.layers):
if isinstance(l, tf_shell_ml.GlobalAveragePooling1D):
dw, dx = l.backward(dJ_dx[-1])
@@ -153,6 +153,29 @@ def train_step(self, data):
dJ_dw.extend(dw)
dJ_dx.append(dx)

# Check if the backpropagated gradients overflowed.
if not self.disable_encryption and self.check_overflow_INSECURE:
# Note, checking the backprop gradients requires decryption
# on the features party which breaks security of the protocol.
bp_scaling_factors = [g._scaling_factor for g in dJ_dw]
dec_dJ_dw = [
tf_shell.to_tensorflow(
g,
(
backprop_secret_fastrot_key
if g._is_fast_rotated
else backprop_secret_key
),
)
for g in dJ_dw
]
self.warn_on_overflow(
dec_dJ_dw,
bp_scaling_factors,
backprop_context.plaintext_modulus,
"WARNING: Backprop gradient may have overflowed.",
)

# Mask the encrypted grads to prepare for decryption. The masks may
# overflow during the reduce_sum over the batch. When the masks are
# operated on, they are multiplied by the scaling factor, so it is
@@ -214,9 +237,16 @@ def train_step(self, data):
grads = [f(g) for f, g in zip(self.unpacking_funcs, packed_grads)]

if not self.disable_noise:
# Efficiently pack the masked gradients to prepare for
# encryption. This is special because the masked gradients are
# no longer batched so the packing must be done manually.
tf.assert_equal(
backprop_context.num_slots,
noise_context.num_slots,
message="Backprop and noise contexts must have the same number of slots.",
)

# Efficiently pack the masked gradients to prepare for adding
# the encrypted noise. This is special because the masked
# gradients are no longer batched, so the packing must be done
# manually.
(
flat_grads,
grad_shapes,
@@ -240,6 +270,16 @@ def train_step(self, data):
# The gradients must first be decrypted using the noise
# secret key.
flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)

if self.check_overflow_INSECURE:
noise_scaling_factors = grads._scaling_factor
self.warn_on_overflow(
[flat_grads],
[noise_scaling_factors],
noise_context.plaintext_modulus,
"WARNING: Noised gradient may have overflowed.",
)

# Unpack the noise after decryption.
grads = self.unflatten_and_unpad_grad(
flat_grads,
@@ -249,10 +289,12 @@ def train_step(self, data):
)

if not self.disable_masking and not self.disable_encryption:
# SHELL represents floats as integers between [0, t) where t is the
# plaintext modulus. To mimic the modulo operation without SHELL,
# numbers which exceed the range [-t/2, t/2) are shifted back into
# the range.
# SHELL represents floats as integers between [0, t) where t is
# the plaintext modulus. To mimic the modulo operation without
# SHELL, numbers which exceed the range [-t/2, t/2) are shifted
# back into the range. In this context, t is the plaintext
modulus of the backprop context, since that is what the gradients
# were encrypted with when the mask was added.
epsilon = tf.constant(1e-6, dtype=float)

def rebalance(x, s):
@@ -274,22 +316,6 @@ def rebalance(x, s):
grads = [mg - m for mg, m in zip(grads, unpacked_mask)]
grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]

if not self.disable_encryption and self.check_overflow:
# If the unmasked gradient is between [-t/2, -t/4] or
# [t/4, t/2], the gradient may have overflowed. Note this must
# also take the scaling factor into account.
overflowed = [
tf.abs(g) > t_half / 2 / s
for g, s in zip(grads, mask_scaling_factors)
]
overflowed = [tf.reduce_any(o) for o in overflowed]
overflowed = tf.reduce_any(overflowed)
tf.cond(
overflowed,
lambda: tf.print("WARNING: Gradient may have overflowed."),
lambda: tf.identity(overflowed),
)

# Apply the gradients to the model.
self.optimizer.apply_gradients(zip(grads, self.weights))

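For readers of this diff: the masking comments above describe SHELL's plaintext encoding, where values live in [0, t) and anything that should be negative wraps into the upper half of that range. The `rebalance(x, s)` helper itself is collapsed in this view, so the following is only a plaintext sketch of a single-wrap recentering consistent with those comments; the names `rebalance_sketch`, `scaling_factor`, and `plaintext_modulus` are illustrative, not the repository's exact implementation.

import tensorflow as tf

def rebalance_sketch(x, scaling_factor, plaintext_modulus):
    # After decoding, values occupy [0, t / s). Anything in the upper
    # half of that range actually encodes a negative number, so shift
    # it down by the full range; the symmetric case handles values that
    # wrapped the other way after the mask subtraction.
    range_size = tf.cast(plaintext_modulus, x.dtype) / scaling_factor
    x = tf.where(x >= range_size / 2, x - range_size, x)
    x = tf.where(x < -range_size / 2, x + range_size, x)
    return x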
27 changes: 27 additions & 0 deletions tf_shell_ml/model_base.py
@@ -242,3 +242,30 @@ def unflatten_and_unpad_grad(
# Reshape to the original shapes.
grads_list = [tf.reshape(g, s) for g, s in zip(grads_list, grad_shapes)]
return grads_list

def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
# If the gradient is between [-t/2, -t/4] or [t/4, t/2], the gradient
# may have overflowed. This also must take the scaling factor into
# account so the range is divided by the scaling factor.
t = tf.cast(plaintext_modulus, grads[0].dtype)
t_half = t / 2

over_by = [
tf.reduce_max(tf.abs(g) - t_half / 2 / s)
for g, s in zip(grads, scaling_factors)
]
max_over_by = tf.reduce_max(over_by)
overflowed = tf.reduce_any(max_over_by > 0)

tf.cond(
overflowed,
lambda: tf.print(
message,
"Overflowed by",
over_by,
"(positive number indicates overflow amount).",
),
lambda: tf.identity(overflowed),
)

return overflowed
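To make the threshold in `warn_on_overflow` above concrete: with plaintext modulus t and scaling factor s, any decoded value whose magnitude exceeds t / (4 * s) lies in the outer half of the representable range and is flagged. Below is a minimal standalone version of the same arithmetic on plain tensors, with made-up numbers and no SHELL objects involved.

import tensorflow as tf

def overflow_margin(grads, scaling_factors, plaintext_modulus):
    # Positive result means some element exceeded t / (4 * s), mirroring
    # the over_by computation in warn_on_overflow.
    t_half = tf.cast(plaintext_modulus, grads[0].dtype) / 2
    margins = [
        tf.reduce_max(tf.abs(g) - t_half / 2 / s)
        for g, s in zip(grads, scaling_factors)
    ]
    return tf.reduce_max(margins)

# With t = 2**16 and s = 8, the warning threshold is 65536 / 32 = 2048.
g = tf.constant([100.0, 2100.0])
print(overflow_margin([g], [8.0], 2**16))  # 52.0 > 0, so a warning would fire.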
62 changes: 41 additions & 21 deletions tf_shell_ml/postscale_sequential_model.py
@@ -31,8 +31,10 @@ def __init__(
disable_encryption=False,
disable_masking=False,
disable_noise=False,
check_overflow=False,
check_overflow_INSECURE=False,
cache_path=None,
jacobian_pfor=False,
jacobian_pfor_iterations=None,
*args,
**kwargs,
):
Expand All @@ -59,8 +61,10 @@ def __init__(
self.noise_multiplier = noise_multiplier
self.disable_masking = disable_masking
self.disable_noise = disable_noise
self.check_overflow = check_overflow
self.check_overflow_INSECURE = check_overflow_INSECURE
self.cache_path = cache_path
self.jacobian_pfor = jacobian_pfor
self.jacobian_pfor_iterations = jacobian_pfor_iterations

def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs):
super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)
@@ -101,8 +105,9 @@ def train_step(self, data):
grads = tape.jacobian(
y_pred,
self.trainable_variables,
parallel_iterations=1,
experimental_use_pfor=False,
unconnected_gradients="zero",
parallel_iterations=self.jacobian_pfor_iterations,
experimental_use_pfor=self.jacobian_pfor,
)
# grads = tape.jacobian(y_pred, self.trainable_variables, experimental_use_pfor=True)
# ^ layers list x (batch size x num output classes x weights) matrix
@@ -147,6 +152,22 @@ def train_step(self, data):
)
# ^ layers list (batch_size x weights)

# Check if the post-scaled gradients overflowed.
if not self.disable_encryption and self.check_overflow_INSECURE:
# Note, checking the backprop gradients requires decryption
# on the features party which breaks security of the protocol.
bp_scaling_factors = [g._scaling_factor for g in grads]
dec_grads = [
tf_shell.to_tensorflow(g, backprop_secret_fastrot_key)
for g in grads
]
self.warn_on_overflow(
dec_grads,
bp_scaling_factors,
backprop_context.plaintext_modulus,
"WARNING: Backprop gradient may have overflowed.",
)

# Mask the encrypted grads to prepare for decryption. The masks may
# overflow during the reduce_sum over the batch. When the masks are
# operated on, they are multiplied by the scaling factor, so it is
@@ -195,6 +216,12 @@ def train_step(self, data):
]

if not self.disable_noise:
tf.assert_equal(
backprop_context.num_slots,
noise_context.num_slots,
message="Backprop and noise contexts must have the same number of slots.",
)

# Efficiently pack the masked gradients to prepare for
# encryption. This is special because the masked gradients are
# no longer batched so the packing must be done manually.
@@ -217,7 +244,16 @@ def train_step(self, data):
if not self.disable_noise:
# The gradients must first be decrypted using the noise
# secret key.
grads = tf_shell.to_tensorflow(grads, noise_secret_key)
flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)

if not self.disable_encryption and self.check_overflow_INSECURE:
noise_scaling_factor = grads._scaling_factor
self.warn_on_overflow(
[flat_grads],
[noise_scaling_factor],
noise_context.plaintext_modulus,
"WARNING: Noised gradient may have overflowed.",
)
# Unpack the noise after decryption.
grads = self.unflatten_and_unpad_grad(
flat_grads,
Expand Down Expand Up @@ -252,22 +288,6 @@ def rebalance(x, s):

grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]

if not self.disable_encryption and self.check_overflow:
# If the unmasked gradient is between [-t/2, -t/4] or
# [t/4, t/2], the gradient may have overflowed. Note this must
# also take the scaling factor into account.
overflowed = [
tf.abs(g) > t_half / 2 / s
for g, s in zip(grads, mask_scaling_factors)
]
overflowed = [tf.reduce_any(o) for o in overflowed]
overflowed = tf.reduce_any(overflowed)
tf.cond(
overflowed,
lambda: tf.print("Gradient may have overflowed"),
lambda: tf.identity(overflowed),
)

# Apply the gradients to the model.
self.optimizer.apply_gradients(zip(grads, self.weights))

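For orientation on the post-scale path above: the Jacobian of the predictions with respect to each weight tensor is computed per example, then scaled by the loss gradient with respect to the predictions and reduced over the class axis, yielding the per-example gradients that are later summed over the batch. The following is a plaintext-only sketch of that contraction under assumed shapes (one dense layer, with (y_pred - y) standing in for the real loss_fn.grad), not the encrypted implementation in this file.

import tensorflow as tf

x = tf.random.normal([4, 8])               # 4 examples, 8 features
w = tf.Variable(tf.random.normal([8, 3]))  # one dense weight, 3 classes
y = tf.one_hot([0, 1, 2, 0], 3)

with tf.GradientTape() as tape:
    y_pred = tf.nn.softmax(x @ w)

# Shape (batch, classes, *w.shape) = (4, 3, 8, 3), i.e. the
# "batch size x num output classes x weights" matrix noted in the diff.
jac = tape.jacobian(y_pred, w)

# Stand-in for the loss gradient w.r.t. the predictions, shape (4, 3).
dl_dy = y_pred - y

# Scale the Jacobian by dl_dy and sum over the class axis: per-example
# gradients of shape (4, 8, 3), matching "(batch_size x weights)".
per_example_grads = tf.reduce_sum(dl_dy[:, :, None, None] * jac, axis=1)

# Reducing over the batch gives the aggregate gradient for the update.
grad = tf.reduce_sum(per_example_grads, axis=0)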
20 changes: 11 additions & 9 deletions tf_shell_ml/test/dpsgd_conv_model_local_test.py
@@ -80,24 +80,24 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
),
],
backprop_context_fn=lambda: tf_shell.create_autocontext64(
log2_cleartext_sz=26,
scaling_factor=2,
noise_offset_log2=46,
log2_cleartext_sz=24,
scaling_factor=1,
noise_offset_log2=0,
cache_path=context_cache_path,
),
noise_context_fn=lambda: tf_shell.create_autocontext64(
log2_cleartext_sz=26,
scaling_factor=2,
noise_offset_log2=47,
scaling_factor=1,
noise_offset_log2=0,
cache_path=context_cache_path,
),
disable_encryption=disable_encryption,
disable_masking=disable_masking,
disable_noise=disable_noise,
cache_path=context_cache_path,
check_overflow=True,
jacobian_pfor=False,
jacobian_pfor_iterations=None,
# check_overflow_INSECURE=True,
# jacobian_pfor=True,
# jacobian_pfor_iterations=128,
)

m.compile(
Expand All @@ -110,7 +110,9 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise):
m.build([None, 28, 28, 1])
m.summary()

history = m.fit(train_dataset.take(9), epochs=1, validation_data=val_dataset)
history = m.fit(train_dataset.take(16), epochs=1, validation_data=val_dataset)

context_cache.cleanup()

self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.30)

8 changes: 4 additions & 4 deletions tf_shell_ml/test/dpsgd_model_distrib_test.py
@@ -97,13 +97,13 @@ def test_model(self):
),
],
lambda: tf_shell.create_autocontext64(
log2_cleartext_sz=14,
scaling_factor=8,
noise_offset_log2=12,
log2_cleartext_sz=10,
scaling_factor=5,
noise_offset_log2=0,
cache_path=context_cache_path,
),
lambda: tf_shell.create_autocontext64(
log2_cleartext_sz=14,
log2_cleartext_sz=11,
scaling_factor=8,
noise_offset_log2=0,
cache_path=context_cache_path,