Skip to content

Commit

Permalink
Operations which accept either pt or ct will mimic tf-shell's ct beha…
Browse files Browse the repository at this point in the history
…vior when passed pt.
  • Loading branch information
james-choncholas committed Oct 26, 2024
1 parent 0f7a688 commit a682db2
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 111 deletions.
77 changes: 71 additions & 6 deletions tf_shell/python/shell_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def to_tensorflow(s_tensor, key=None):
)


def roll(x, shift, rotation_key):
def roll(x, shift, rotation_key=None):
if isinstance(x, ShellTensor64):
if not isinstance(rotation_key, ShellRotationKey64):
raise ValueError(
Expand Down Expand Up @@ -734,7 +734,16 @@ def roll(x, shift, rotation_key):
_is_fast_rotated=x._is_fast_rotated,
)
elif isinstance(x, tf.Tensor):
return tf.roll(x, shift)
# TensorFlow's roll has slightly different semantics than tf-shell's
# roll. Encrypted rotation affects top and bottom halves independently.
# This function emulates this in plaintext by splitting the tensor in
# half, rotating each half, and then concatenating them back together.
top, bottom = tf.split(x, num_or_size_splits=2, axis=0)
top = tf.roll(top, shift, axis=0)
bottom = tf.roll(bottom, shift, axis=0)
rotated_tftensor = tf.concat([top, bottom], axis=0)
return rotated_tftensor

else:
raise ValueError(f"Unsupported type for roll. Got {type(x)}.")

Expand Down Expand Up @@ -786,7 +795,22 @@ def reduce_sum(x, axis, rotation_key=None):
_is_fast_rotated=x._is_fast_rotated,
)
elif isinstance(x, tf.Tensor):
return tf.reduce_sum(x, axis)
if axis == 0:
# TensorFlow's reduce_sum over axis 0 (the slotting dimension) has
# slightly different semantics than tf-shell's reduce_sum. Encrypted
# reduce_sum affects top and bottom halves independently, as well as
# repeating the sum across the halves. This emulates this in
# plaintext.
half_slots = x.shape[0] // 2
bottom_answer = tf.math.reduce_sum(x[0:half_slots], axis=0, keepdims=True)
top_answer = tf.math.reduce_sum(x[half_slots:], axis=0, keepdims=True)

repeated_bottom_answer = tf.repeat(bottom_answer, repeats=half_slots, axis=0)
repeated_top_answer = tf.repeat(top_answer, repeats=half_slots, axis=0)

return tf.concat([repeated_bottom_answer, repeated_top_answer], 0)
else:
return tf.reduce_sum(x, axis)
else:
raise ValueError(f"Unsupported type for reduce_sum. Got {type(x)}.")

Expand Down Expand Up @@ -842,7 +866,7 @@ def fast_reduce_sum(x):
)


def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"):
def matmul(x, y, rotation_key=None, pt_ct_reduction="galois", emulate_pt_ct=False):
"""Matrix multiplication is specialized to whether the operands are
plaintext or ciphertext.
Expand Down Expand Up @@ -935,7 +959,26 @@ def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"):
return NotImplementedError

elif isinstance(x, tf.Tensor) and isinstance(y, tf.Tensor):
return tf.matmul(x, y)
if emulate_pt_ct:
# tf-shell matmult has slightly different semantics than plaintext /
# Tensorflow. Encrypted matmult affects top and bottom halves
# independently, as well as the first dimension repeating the sum of
# either the halves. This function emulates this in plaintext with
# element-wise multiplication, and an optional reduction.
shape_range = range(len(x.shape))
x = tf.transpose(x, perm=[shape_range[-1]] + list(shape_range[:-1]))
x = tf.expand_dims(x, axis=-1)
for _ in range(len(x.shape) - 2):
y = tf.expand_dims(y, axis=-2)
res = x * y

if pt_ct_reduction != "none":
res = reduce_sum(res, axis=0)

return res

else:
return tf.matmul(x, y)

else:
raise ValueError(
Expand Down Expand Up @@ -1098,7 +1141,29 @@ def segment_sum(x, segments, num_segments, rotation_key=None, reduction="galois"
reduction_count,
)
elif isinstance(x, tf.Tensor):
return tf.math.unsorted_segment_sum(x, segments, num_segments)
# tf-shell segment functions differs from tensorflow in the following
# ways: First, the ciphertext dimension is included in the output, but
# only one dimension is valid. For the top half of the ciphertext, the
# first dimension is valid, and for the bottom half, the `num_slots //
# 2`th dimension is valid.
# Second, the reduction only happens across half of the batching
# dimension, due to how rotations in tf-shell work. Segment reduction
# happens on the top and bottom halves of the ciphertext independently.
if reduction == "none":
raise ValueError(
"Plaintext segment_sum does not support `none` reduction."
)
half_slots = x.shape[0] // 2
padding = tf.zeros_like(x[:half_slots])

x_top = tf.concat([x[:half_slots], padding], 0)
x_bottom = tf.concat([padding, x[half_slots:]], 0)

top_answer = tf.math.unsorted_segment_sum(x_top, segments, num_segments)
bottom_answer = tf.math.unsorted_segment_sum(x_bottom, segments, num_segments)

return top_answer, bottom_answer

else:
raise ValueError("Unsupported type for segment_sum")

Expand Down
2 changes: 1 addition & 1 deletion tf_shell/test/distribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_distribution(self):
self.assertAllClose(c, a + b, atol=1)
self.assertAllClose(e, a + d, atol=1)
self.assertAllClose(
f, test_utils.plaintext_reduce_sum_axis_0(a + d), atol=1, rtol=1e-2
f, tf_shell.reduce_sum(a + d, axis=0), atol=1, rtol=1e-2
)


Expand Down
35 changes: 1 addition & 34 deletions tf_shell/test/mat_mul_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,39 +117,6 @@ def test_ct_tf_matmul(self):
tf.config.run_functions_eagerly(eager)
self._test_ct_tf_matmul(test_context)

# tf-shell matmult has slightly different semantics than plaintext /
# Tensorflow. Encrypted matmult affects top and bottom halves independently,
# as well as the first dimension repeating the sum of either the halves.
# This function emulates this in plaintext.
def plaintext_matmul(self, a, b):
half_slots = b.shape[0] // 2
tf_half_slots = tf.constant([half_slots], dtype=tf.int32)

a_shape = tf.shape(a)
a_top_start = tf.zeros_like(a_shape)
a_top_shape = tf.concat([a_shape[:-1], tf_half_slots], axis=0)
a_top = tf.slice(a, a_top_start, a_top_shape)
a_bottom_start = tf.concat(
[tf.zeros_like(a_top_start[:-1]), tf_half_slots], axis=0
)
a_bottom_shape = tf.concat([a_shape[:-1], tf_half_slots], axis=0)
a_bottom = tf.slice(a, a_bottom_start, a_bottom_shape)

assert len(tf.shape(b)) == 2
b_top = b[:half_slots, :]
b_bottom = b[half_slots:, :]

top = tf.matmul(a_top, b_top)
bottom = tf.matmul(a_bottom, b_bottom)

top = tf.expand_dims(top, axis=0)
bottom = tf.expand_dims(bottom, axis=0)

top = tf.repeat(top, repeats=[half_slots], axis=0)
bottom = tf.repeat(bottom, repeats=[half_slots], axis=0)

return tf.concat([top, bottom], axis=0)

def _test_tf_ct_matmul(self, test_context, use_fast_rotation):
# Generating the following tensors should always succeed since this test
# uses it's own special context.
Expand All @@ -174,7 +141,7 @@ def _test_tf_ct_matmul(self, test_context, use_fast_rotation):
return

eb = tf_shell.to_encrypted(b, test_context.key, test_context.shell_context)
check_c = self.plaintext_matmul(a, b)
check_c = tf_shell.matmul(a, b, emulate_pt_ct=True)

@tf.function
def test_functor():
Expand Down
6 changes: 3 additions & 3 deletions tf_shell/test/rotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _test_roll(self, test_context, roll_num):
tftensor, multiples=[1] * (i + 1) + [test_context.outer_shape[i]]
)

rolled_tftensor = test_utils.plaintext_roll(tftensor, roll_num)
rolled_tftensor = tf_shell.roll(tftensor, roll_num)

s = tf_shell.to_shell_plaintext(tftensor, test_context.shell_context)
enc = tf_shell.to_encrypted(s, test_context.key)
Expand Down Expand Up @@ -148,7 +148,7 @@ def _test_roll_mod_reduced(self, test_context, roll_num):
tftensor, multiples=[1] * (i + 1) + [test_context.outer_shape[i]]
)

rolled_tftensor = test_utils.plaintext_roll(tftensor, roll_num)
rolled_tftensor = tf_shell.roll(tftensor, roll_num)

s = tf_shell.to_shell_plaintext(tftensor, test_context.shell_context)
enc = tf_shell.to_encrypted(s, test_context.key)
Expand Down Expand Up @@ -205,7 +205,7 @@ def _test_reduce_sum_axis_0(self, test_context):

tftensor_out = tf_shell.to_tensorflow(enc_reduce_sum, test_context.key)
self.assertAllClose(
tftensor_out, test_utils.plaintext_reduce_sum_axis_0(tftensor), atol=1e-3
tftensor_out, tf_shell.reduce_sum(tftensor, axis=0), atol=1e-3
)

def test_reduce_sum_axis_0(self):
Expand Down
2 changes: 1 addition & 1 deletion tf_shell/test/rotation_test_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _test_fast_reduce_sum_axis_0(self, test_context):
enc_reduce_sum, key=test_context.fast_rotation_key
)
self.assertAllClose(
tftensor_out, test_utils.plaintext_reduce_sum_axis_0(tftensor), atol=1e-3
tftensor_out, tf_shell.reduce_sum(tftensor, axis=0), atol=1e-3
)

def test_fast_reduce_sum_axis_0(self):
Expand Down
40 changes: 2 additions & 38 deletions tf_shell/test/segment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,42 +44,6 @@ def setUpClass(cls):
def tearDownClass(cls):
cls.rotation_test_contexts = None

# tf-shell segment functions differs from tensorflow in the following ways:
# First, the ciphertext dimension is included in the output, but only one
# dimension is valid. For the top half of the ciphertext, the first
# dimension is valid, and for the bottom half, the `num_slots // 2`th
# dimension is valid.
# Second, the reduction only happens across half of the batching dimension,
# due to how rotations in tf-shell work. Segment reduction happens on the
# top and bottom halves of the ciphertext independently.
def plaintext_segment_sum(self, x, segments, num_segments, start_segment=0):
half_slots = x.shape[0] // 2
padding = tf.zeros_like(x[:half_slots])

x_top = tf.concat([x[:half_slots], padding], 0)
x_bottom = tf.concat([padding, x[half_slots:]], 0)

top_answer = tf.math.unsorted_segment_sum(x_top, segments, num_segments)
bottom_answer = tf.math.unsorted_segment_sum(x_bottom, segments, num_segments)

if start_segment > 0:
top_answer = tf.concat(
[
tf.zeros_like(top_answer[:start_segment]),
top_answer[start_segment:],
],
axis=0,
)
bottom_answer = tf.concat(
[
tf.zeros_like(bottom_answer[:start_segment]),
bottom_answer[start_segment:],
],
axis=0,
)

return top_answer, bottom_answer

def create_rand_data(self, test_context, repeats):
try:
shape_prod = math.prod(test_context.outer_shape)
Expand Down Expand Up @@ -163,7 +127,7 @@ def test_functor(ea, segments, num_segments, rot_key):

ss = tf_shell.to_tensorflow(ess, test_context.key)

pt_ss_top, pt_ss_bottom = self.plaintext_segment_sum(a, segments, num_segments)
pt_ss_top, pt_ss_bottom = tf_shell.segment_sum(a, segments, num_segments)

# Ensure the reduced data is correct.
self.assertAllClose(pt_ss_top, ss[0][0])
Expand Down Expand Up @@ -315,7 +279,7 @@ def test_functor(ea, segments, num_segments, rot_key):

ss = tf_shell.to_tensorflow(ess, test_context.key)

pt_ss_top, pt_ss_bottom = self.plaintext_segment_sum(a, segments, num_segments)
pt_ss_top, pt_ss_bottom = tf_shell.segment_sum(a, segments, num_segments)

# Ensure the data is correctly reduced.
self.assertAllClose(pt_ss_top, ss[0][0])
Expand Down
27 changes: 0 additions & 27 deletions tf_shell/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,30 +224,3 @@ def uniform_for_n_muls(test_context, num_muls, shape=None, subsequent_adds=0):
rand = tf.cast(rand, test_context.plaintext_dtype)

return rand


# TensorFlow's roll has slightly different semantics than tf-shell's roll.
# Encrypted rotation affects top and bottom halves independently.
# This function emulates this in plaintext by splitting the tensor in half,
# rotating each half, and then concatenating them back together.
def plaintext_roll(t, shift):
top, bottom = tf.split(t, num_or_size_splits=2, axis=0)
top = tf.roll(top, shift, axis=0)
bottom = tf.roll(bottom, shift, axis=0)
rotated_tftensor = tf.concat([top, bottom], axis=0)
return rotated_tftensor


# TensorFlow's reduce_sum has slightly different semantics than tf-shell's
# reduce_sum. Encrypted reduce_sum affects top and bottom halves
# independently, as well as repeating the sum across the halves. This
# function emulates this in plaintext.
def plaintext_reduce_sum_axis_0(t):
half_slots = t.shape[0] // 2
bottom_answer = tf.math.reduce_sum(t[0:half_slots], axis=0, keepdims=True)
top_answer = tf.math.reduce_sum(t[half_slots:], axis=0, keepdims=True)

repeated_bottom_answer = tf.repeat(bottom_answer, repeats=half_slots, axis=0)
repeated_top_answer = tf.repeat(top_answer, repeats=half_slots, axis=0)

return tf.concat([repeated_bottom_answer, repeated_top_answer], 0)
6 changes: 5 additions & 1 deletion tf_shell_ml/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def backward(self, dy, rotation_key):

# Perform the multiplication for dy/dw.
d_w = tf_shell.matmul(
tf.transpose(x), dy, rotation_key, pt_ct_reduction=self.grad_reduction
tf.transpose(x),
dy,
rotation_key,
pt_ct_reduction=self.grad_reduction,
emulate_pt_ct=True,
)
d_ws.append(d_w)

Expand Down

0 comments on commit a682db2

Please sign in to comment.