diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index b223b03..53705bf 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -703,7 +703,7 @@ def to_tensorflow(s_tensor, key=None): ) -def roll(x, shift, rotation_key): +def roll(x, shift, rotation_key=None): if isinstance(x, ShellTensor64): if not isinstance(rotation_key, ShellRotationKey64): raise ValueError( @@ -734,7 +734,16 @@ def roll(x, shift, rotation_key): _is_fast_rotated=x._is_fast_rotated, ) elif isinstance(x, tf.Tensor): - return tf.roll(x, shift) + # TensorFlow's roll has slightly different semantics than tf-shell's + # roll. Encrypted rotation affects top and bottom halves independently. + # This function emulates this in plaintext by splitting the tensor in + # half, rotating each half, and then concatenating them back together. + top, bottom = tf.split(x, num_or_size_splits=2, axis=0) + top = tf.roll(top, shift, axis=0) + bottom = tf.roll(bottom, shift, axis=0) + rotated_tftensor = tf.concat([top, bottom], axis=0) + return rotated_tftensor + else: raise ValueError(f"Unsupported type for roll. Got {type(x)}.") @@ -786,7 +795,22 @@ def reduce_sum(x, axis, rotation_key=None): _is_fast_rotated=x._is_fast_rotated, ) elif isinstance(x, tf.Tensor): - return tf.reduce_sum(x, axis) + if axis == 0: + # TensorFlow's reduce_sum over axis 0 (the slotting dimension) has + # slightly different semantics than tf-shell's reduce_sum. Encrypted + # reduce_sum affects top and bottom halves independently, as well as + # repeating the sum across the halves. This emulates this in + # plaintext. 
+ half_slots = x.shape[0] // 2 + bottom_answer = tf.math.reduce_sum(x[0:half_slots], axis=0, keepdims=True) + top_answer = tf.math.reduce_sum(x[half_slots:], axis=0, keepdims=True) + + repeated_bottom_answer = tf.repeat(bottom_answer, repeats=half_slots, axis=0) + repeated_top_answer = tf.repeat(top_answer, repeats=half_slots, axis=0) + + return tf.concat([repeated_bottom_answer, repeated_top_answer], 0) + else: + return tf.reduce_sum(x, axis) else: raise ValueError(f"Unsupported type for reduce_sum. Got {type(x)}.") @@ -842,7 +866,7 @@ def fast_reduce_sum(x): ) -def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"): +def matmul(x, y, rotation_key=None, pt_ct_reduction="galois", emulate_pt_ct=False): """Matrix multiplication is specialized to whether the operands are plaintext or ciphertext. @@ -935,7 +959,26 @@ def matmul(x, y, rotation_key=None, pt_ct_reduction="galois"): return NotImplementedError elif isinstance(x, tf.Tensor) and isinstance(y, tf.Tensor): - return tf.matmul(x, y) + if emulate_pt_ct: + # tf-shell matmul has slightly different semantics than plaintext / + # TensorFlow. Encrypted matmul affects top and bottom halves + # independently, and the first dimension repeats the sum of each + # half. This function emulates this in plaintext with + # element-wise multiplication, and an optional reduction.
+ shape_range = range(len(x.shape)) + x = tf.transpose(x, perm=[shape_range[-1]] + list(shape_range[:-1])) + x = tf.expand_dims(x, axis=-1) + for _ in range(len(x.shape) - 2): + y = tf.expand_dims(y, axis=-2) + res = x * y + + if pt_ct_reduction != "none": + res = reduce_sum(res, axis=0) + + return res + + else: + return tf.matmul(x, y) else: raise ValueError( @@ -1098,7 +1141,29 @@ def segment_sum(x, segments, num_segments, rotation_key=None, reduction="galois" reduction_count, ) elif isinstance(x, tf.Tensor): - return tf.math.unsorted_segment_sum(x, segments, num_segments) + # tf-shell segment functions differ from TensorFlow in the following + # ways: First, the ciphertext dimension is included in the output, but + # only one dimension is valid. For the top half of the ciphertext, the + # first dimension is valid, and for the bottom half, the `num_slots // + # 2`th dimension is valid. + # Second, the reduction only happens across half of the batching + # dimension, due to how rotations in tf-shell work. Segment reduction + # happens on the top and bottom halves of the ciphertext independently. + if reduction == "none": + raise ValueError( + "Plaintext segment_sum does not support `none` reduction."
+ ) + half_slots = x.shape[0] // 2 + padding = tf.zeros_like(x[:half_slots]) + + x_top = tf.concat([x[:half_slots], padding], 0) + x_bottom = tf.concat([padding, x[half_slots:]], 0) + + top_answer = tf.math.unsorted_segment_sum(x_top, segments, num_segments) + bottom_answer = tf.math.unsorted_segment_sum(x_bottom, segments, num_segments) + + return top_answer, bottom_answer + else: raise ValueError("Unsupported type for segment_sum") diff --git a/tf_shell/test/distribution_test.py b/tf_shell/test/distribution_test.py index ec627df..9377273 100644 --- a/tf_shell/test/distribution_test.py +++ b/tf_shell/test/distribution_test.py @@ -98,7 +98,7 @@ def test_distribution(self): self.assertAllClose(c, a + b, atol=1) self.assertAllClose(e, a + d, atol=1) self.assertAllClose( - f, test_utils.plaintext_reduce_sum_axis_0(a + d), atol=1, rtol=1e-2 + f, tf_shell.reduce_sum(a + d, axis=0), atol=1, rtol=1e-2 ) diff --git a/tf_shell/test/mat_mul_test.py b/tf_shell/test/mat_mul_test.py index bfe522f..7ae5fca 100644 --- a/tf_shell/test/mat_mul_test.py +++ b/tf_shell/test/mat_mul_test.py @@ -117,39 +117,6 @@ def test_ct_tf_matmul(self): tf.config.run_functions_eagerly(eager) self._test_ct_tf_matmul(test_context) - # tf-shell matmult has slightly different semantics than plaintext / - # Tensorflow. Encrypted matmult affects top and bottom halves independently, - # as well as the first dimension repeating the sum of either the halves. - # This function emulates this in plaintext. 
- def plaintext_matmul(self, a, b): - half_slots = b.shape[0] // 2 - tf_half_slots = tf.constant([half_slots], dtype=tf.int32) - - a_shape = tf.shape(a) - a_top_start = tf.zeros_like(a_shape) - a_top_shape = tf.concat([a_shape[:-1], tf_half_slots], axis=0) - a_top = tf.slice(a, a_top_start, a_top_shape) - a_bottom_start = tf.concat( - [tf.zeros_like(a_top_start[:-1]), tf_half_slots], axis=0 - ) - a_bottom_shape = tf.concat([a_shape[:-1], tf_half_slots], axis=0) - a_bottom = tf.slice(a, a_bottom_start, a_bottom_shape) - - assert len(tf.shape(b)) == 2 - b_top = b[:half_slots, :] - b_bottom = b[half_slots:, :] - - top = tf.matmul(a_top, b_top) - bottom = tf.matmul(a_bottom, b_bottom) - - top = tf.expand_dims(top, axis=0) - bottom = tf.expand_dims(bottom, axis=0) - - top = tf.repeat(top, repeats=[half_slots], axis=0) - bottom = tf.repeat(bottom, repeats=[half_slots], axis=0) - - return tf.concat([top, bottom], axis=0) - def _test_tf_ct_matmul(self, test_context, use_fast_rotation): # Generating the following tensors should always succeed since this test # uses it's own special context. 
@@ -174,7 +141,7 @@ def _test_tf_ct_matmul(self, test_context, use_fast_rotation): return eb = tf_shell.to_encrypted(b, test_context.key, test_context.shell_context) - check_c = self.plaintext_matmul(a, b) + check_c = tf_shell.matmul(a, b, emulate_pt_ct=True) @tf.function def test_functor(): diff --git a/tf_shell/test/rotation_test.py b/tf_shell/test/rotation_test.py index b9d5ccc..8b4c881 100644 --- a/tf_shell/test/rotation_test.py +++ b/tf_shell/test/rotation_test.py @@ -93,7 +93,7 @@ def _test_roll(self, test_context, roll_num): tftensor, multiples=[1] * (i + 1) + [test_context.outer_shape[i]] ) - rolled_tftensor = test_utils.plaintext_roll(tftensor, roll_num) + rolled_tftensor = tf_shell.roll(tftensor, roll_num) s = tf_shell.to_shell_plaintext(tftensor, test_context.shell_context) enc = tf_shell.to_encrypted(s, test_context.key) @@ -148,7 +148,7 @@ def _test_roll_mod_reduced(self, test_context, roll_num): tftensor, multiples=[1] * (i + 1) + [test_context.outer_shape[i]] ) - rolled_tftensor = test_utils.plaintext_roll(tftensor, roll_num) + rolled_tftensor = tf_shell.roll(tftensor, roll_num) s = tf_shell.to_shell_plaintext(tftensor, test_context.shell_context) enc = tf_shell.to_encrypted(s, test_context.key) @@ -205,7 +205,7 @@ def _test_reduce_sum_axis_0(self, test_context): tftensor_out = tf_shell.to_tensorflow(enc_reduce_sum, test_context.key) self.assertAllClose( - tftensor_out, test_utils.plaintext_reduce_sum_axis_0(tftensor), atol=1e-3 + tftensor_out, tf_shell.reduce_sum(tftensor, axis=0), atol=1e-3 ) def test_reduce_sum_axis_0(self): diff --git a/tf_shell/test/rotation_test_fast.py b/tf_shell/test/rotation_test_fast.py index 69242a3..7618815 100644 --- a/tf_shell/test/rotation_test_fast.py +++ b/tf_shell/test/rotation_test_fast.py @@ -80,7 +80,7 @@ def _test_fast_reduce_sum_axis_0(self, test_context): enc_reduce_sum, key=test_context.fast_rotation_key ) self.assertAllClose( - tftensor_out, test_utils.plaintext_reduce_sum_axis_0(tftensor), atol=1e-3 + 
tftensor_out, tf_shell.reduce_sum(tftensor, axis=0), atol=1e-3 ) def test_fast_reduce_sum_axis_0(self): diff --git a/tf_shell/test/segment_test.py b/tf_shell/test/segment_test.py index 6650e38..d1db3f9 100644 --- a/tf_shell/test/segment_test.py +++ b/tf_shell/test/segment_test.py @@ -44,42 +44,6 @@ def setUpClass(cls): def tearDownClass(cls): cls.rotation_test_contexts = None - # tf-shell segment functions differs from tensorflow in the following ways: - # First, the ciphertext dimension is included in the output, but only one - # dimension is valid. For the top half of the ciphertext, the first - # dimension is valid, and for the bottom half, the `num_slots // 2`th - # dimension is valid. - # Second, the reduction only happens across half of the batching dimension, - # due to how rotations in tf-shell work. Segment reduction happens on the - # top and bottom halves of the ciphertext independently. - def plaintext_segment_sum(self, x, segments, num_segments, start_segment=0): - half_slots = x.shape[0] // 2 - padding = tf.zeros_like(x[:half_slots]) - - x_top = tf.concat([x[:half_slots], padding], 0) - x_bottom = tf.concat([padding, x[half_slots:]], 0) - - top_answer = tf.math.unsorted_segment_sum(x_top, segments, num_segments) - bottom_answer = tf.math.unsorted_segment_sum(x_bottom, segments, num_segments) - - if start_segment > 0: - top_answer = tf.concat( - [ - tf.zeros_like(top_answer[:start_segment]), - top_answer[start_segment:], - ], - axis=0, - ) - bottom_answer = tf.concat( - [ - tf.zeros_like(bottom_answer[:start_segment]), - bottom_answer[start_segment:], - ], - axis=0, - ) - - return top_answer, bottom_answer - def create_rand_data(self, test_context, repeats): try: shape_prod = math.prod(test_context.outer_shape) @@ -163,7 +127,7 @@ def test_functor(ea, segments, num_segments, rot_key): ss = tf_shell.to_tensorflow(ess, test_context.key) - pt_ss_top, pt_ss_bottom = self.plaintext_segment_sum(a, segments, num_segments) + pt_ss_top, pt_ss_bottom = 
tf_shell.segment_sum(a, segments, num_segments) # Ensure the reduced data is correct. self.assertAllClose(pt_ss_top, ss[0][0]) @@ -315,7 +279,7 @@ def test_functor(ea, segments, num_segments, rot_key): ss = tf_shell.to_tensorflow(ess, test_context.key) - pt_ss_top, pt_ss_bottom = self.plaintext_segment_sum(a, segments, num_segments) + pt_ss_top, pt_ss_bottom = tf_shell.segment_sum(a, segments, num_segments) # Ensure the data is correctly reduced. self.assertAllClose(pt_ss_top, ss[0][0]) diff --git a/tf_shell/test/test_utils.py b/tf_shell/test/test_utils.py index dd74b09..59bb291 100644 --- a/tf_shell/test/test_utils.py +++ b/tf_shell/test/test_utils.py @@ -224,30 +224,3 @@ def uniform_for_n_muls(test_context, num_muls, shape=None, subsequent_adds=0): rand = tf.cast(rand, test_context.plaintext_dtype) return rand - - -# TensorFlow's roll has slightly different semantics than tf-shell's roll. -# Encrypted rotation affects top and bottom halves independently. -# This function emulates this in plaintext by splitting the tensor in half, -# rotating each half, and then concatenating them back together. -def plaintext_roll(t, shift): - top, bottom = tf.split(t, num_or_size_splits=2, axis=0) - top = tf.roll(top, shift, axis=0) - bottom = tf.roll(bottom, shift, axis=0) - rotated_tftensor = tf.concat([top, bottom], axis=0) - return rotated_tftensor - - -# TensorFlow's reduce_sum has slightly different semantics than tf-shell's -# reduce_sum. Encrypted reduce_sum affects top and bottom halves -# independently, as well as repeating the sum across the halves. This -# function emulates this in plaintext. 
-def plaintext_reduce_sum_axis_0(t): - half_slots = t.shape[0] // 2 - bottom_answer = tf.math.reduce_sum(t[0:half_slots], axis=0, keepdims=True) - top_answer = tf.math.reduce_sum(t[half_slots:], axis=0, keepdims=True) - - repeated_bottom_answer = tf.repeat(bottom_answer, repeats=half_slots, axis=0) - repeated_top_answer = tf.repeat(top_answer, repeats=half_slots, axis=0) - - return tf.concat([repeated_bottom_answer, repeated_top_answer], 0) diff --git a/tf_shell_ml/dense.py b/tf_shell_ml/dense.py index 7f556cf..8f38c37 100644 --- a/tf_shell_ml/dense.py +++ b/tf_shell_ml/dense.py @@ -120,7 +120,11 @@ def backward(self, dy, rotation_key): # Perform the multiplication for dy/dw. d_w = tf_shell.matmul( - tf.transpose(x), dy, rotation_key, pt_ct_reduction=self.grad_reduction + tf.transpose(x), + dy, + rotation_key, + pt_ct_reduction=self.grad_reduction, + emulate_pt_ct=True, ) d_ws.append(d_w)