diff --git a/shell_tensor/cc/kernels/rotation_kernels.cc b/shell_tensor/cc/kernels/rotation_kernels.cc
index ba8b508..885cd82 100644
--- a/shell_tensor/cc/kernels/rotation_kernels.cc
+++ b/shell_tensor/cc/kernels/rotation_kernels.cc
@@ -42,7 +42,7 @@ using tensorflow::uint8;
 using tensorflow::Variant;
 using tensorflow::errors::InvalidArgument;
 
-constexpr int kLogGadgetBase = 10;
+constexpr int kLogGadgetBase = 4;
 constexpr rlwe::PrngType kPrngType = rlwe::PRNG_TYPE_HKDF;
 
 template <typename T>
@@ -132,6 +132,9 @@ class RollOp : public OpKernel {
   void Compute(OpKernelContext* op_ctx) override {
     OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, op_ctx,
                       GetVariant<RotationKeyVariant<T>>(op_ctx, 0));
+    OP_REQUIRES(
+        op_ctx, rotation_key_var != nullptr,
+        InvalidArgument("RotationKeyVariant a did not unwrap successfully."));
     std::map<int, PowerAndKey> const* keys = &rotation_key_var->keys;
 
     Tensor const& value = op_ctx->input(1);
@@ -167,9 +170,12 @@ class RollOp : public OpKernel {
       shift += num_slots / 2;
     }
 
-    OP_REQUIRES(op_ctx, keys->find(shift) != keys->end(),
-                InvalidArgument("No key for shift of '", shift, "'"));
-    PowerAndKey const& p_and_k = keys->at(shift);
+    PowerAndKey const* p_and_k;
+    if (shift != 0) {
+      OP_REQUIRES(op_ctx, keys->find(shift) != keys->end(),
+                  InvalidArgument("No key for shift of '", shift, "'"));
+      p_and_k = &keys->at(shift);
+    }
 
     for (int i = 0; i < flat_output.dimension(0); ++i) {
       SymmetricCtVariant<T> const* ct_var =
@@ -184,8 +190,8 @@ class RollOp : public OpKernel {
         flat_output(i) = std::move(ct_out_var);
       } else {
         OP_REQUIRES_VALUE(auto ct_sub, op_ctx,
-                          ct.Substitute(p_and_k.substitution_power));
-        OP_REQUIRES_VALUE(auto ct_rot, op_ctx, p_and_k.key.ApplyTo(ct_sub));
+                          ct.Substitute(p_and_k->substitution_power));
+        OP_REQUIRES_VALUE(auto ct_rot, op_ctx, p_and_k->key.ApplyTo(ct_sub));
         SymmetricCtVariant<T> ct_out_var(std::move(ct_rot));
 
         flat_output(i) = std::move(ct_out_var);
@@ -195,7 +201,7 @@
 };
 
 template <typename T>
-class ReduceSumOp : public OpKernel {
+class ReduceSumByRotationOp : public OpKernel {
  private:
   using ModularInt = rlwe::MontgomeryInt<T>;
   using RotationKey = rlwe::RnsGaloisKey<ModularInt>;
@@ -203,28 +209,32 @@
   using PowerAndKey = typename RotationKeyVariant<T>::PowerAndKey;
 
  public:
-  explicit ReduceSumOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {}
+  explicit ReduceSumByRotationOp(OpKernelConstruction* op_ctx)
+      : OpKernel(op_ctx) {}
 
   void Compute(OpKernelContext* op_ctx) override {
-    // Recover the input rotation keys.
-    OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, op_ctx,
-                      GetVariant<RotationKeyVariant<T>>(op_ctx, 0));
-    std::map<int, PowerAndKey> const* keys = &rotation_key_var->keys;
-
     // Recover the input tensor.
-    Tensor const& value = op_ctx->input(1);
+    Tensor const& value = op_ctx->input(0);
     OP_REQUIRES(op_ctx, value.dim_size(0) > 0,
                 InvalidArgument("Cannot reduce_sum an empty ciphertext."));
+    auto flat_value = value.flat<Variant>();
 
-    // Setup the output.
+    // Recover the input rotation keys.
+    OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, op_ctx,
+                      GetVariant<RotationKeyVariant<T>>(op_ctx, 1));
+    OP_REQUIRES(
+        op_ctx, rotation_key_var != nullptr,
+        InvalidArgument("RotationKeyVariant a did not unwrap successfully."));
+    std::map<int, PowerAndKey> const* keys = &rotation_key_var->keys;
+
     Tensor* output;
     OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, value.shape(), &output));
     auto flat_output = output->flat<Variant>();
 
     for (int i = 0; i < flat_output.dimension(0); ++i) {
-      // Learn how many slots there are from first ciphertext and create a deep
-      // copy to hold the sum.
+      // Learn how many slots there are from first ciphertext and create a
+      // deep copy to hold the sum.
       SymmetricCtVariant<T> const* ct_var =
           std::move(flat_value(i).get<SymmetricCtVariant<T>>());
       OP_REQUIRES(
@@ -235,7 +245,7 @@ class ReduceSumOp : public OpKernel {
 
       // Add the rotations to the sum.
       // Note the ciphertext rotations operate on each half of the ciphertext
-      // separately. So the max rotatation is by half the number of slots.
+      // separately. So the max rotation is by half the number of slots.
       for (int shift = 1; shift < num_slots / 2; shift <<= 1) {
         // TODO if debug
         OP_REQUIRES(op_ctx, keys->find(shift) != keys->end(),
@@ -248,7 +258,7 @@ class ReduceSumOp : public OpKernel {
         OP_REQUIRES_VALUE(auto ct_rot, op_ctx, p_and_k.key.ApplyTo(ct_sub));
 
         // Add to the sum.
-        sum.AddInPlace(ct_rot);
+        OP_REQUIRES_OK(op_ctx, sum.AddInPlace(ct_rot));
       }
 
       SymmetricCtVariant<T> ct_out_var(std::move(sum));
@@ -257,10 +267,104 @@ class ReduceSumOp : public OpKernel {
   }
 };
 
+template <typename T>
+class ReduceSumOp : public OpKernel {
+ private:
+  using ModularInt = rlwe::MontgomeryInt<T>;
+  using RotationKey = rlwe::RnsGaloisKey<ModularInt>;
+  using SymmetricCt = rlwe::RnsBgvCiphertext<ModularInt>;
+  using PowerAndKey = typename RotationKeyVariant<T>::PowerAndKey;
+
+ public:
+  explicit ReduceSumOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {}
+
+  void Compute(OpKernelContext* op_ctx) override {
+    // Recover the input tensor.
+    Tensor const& value = op_ctx->input(0);
+    OP_REQUIRES(op_ctx, value.dim_size(0) > 0,
+                InvalidArgument("Cannot reduce_sum an empty ciphertext."));
+
+    // Recover the axis to reduce over.
+    Tensor const& axis_tensor = op_ctx->input(1);
+    OP_REQUIRES(op_ctx, axis_tensor.NumElements() == 1,
+                InvalidArgument("axis must be scalar, saw shape: ",
+                                axis_tensor.shape().DebugString()));
+    OP_REQUIRES_VALUE(int64 axis, op_ctx, GetScalar<int64>(op_ctx, 1));
+
+    // The axis to reduce over.
+    int dim_to_reduce = axis - 1;
+
+    // Check axis is within dim size.
+    OP_REQUIRES(op_ctx, dim_to_reduce < value.dims(),
+                InvalidArgument("Cannot reduce_sum over polynomial_axis '",
+                                dim_to_reduce, "' (axis '", axis,
+                                "') for input with shape ",
+                                value.shape().DebugString()));
+
+    uint8_t dim_sz_to_reduce = value.dim_size(dim_to_reduce);
+
+    // Create a temp Tensor to hold intermediate sums during the reduction.
+    // It is the same size as the input Tensor.
+    Tensor intermediate_sums;
+    auto intermediate_sums_shape = value.shape();
+    OP_REQUIRES_OK(op_ctx, op_ctx->allocate_temp(tensorflow::DT_VARIANT,
+                                                 intermediate_sums_shape,
+                                                 &intermediate_sums));
+
+    auto flat_intermediate_sums =
+        intermediate_sums.flat_inner_outer_dims<Variant, 3>(dim_to_reduce - 1);
+    auto flat_value = value.flat_inner_outer_dims<Variant, 3>(dim_to_reduce - 1);
+
+    // Setup the output.
+    Tensor* output;
+    auto output_shape = value.shape();
+    OP_REQUIRES_OK(op_ctx, output_shape.RemoveDimWithStatus(dim_to_reduce));
+    OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output));
+    // Setup a shape to access the output Tensor as a flat Tensor, with the
+    // same indexing as the input Tensor excluding the dimension to reduce.
+    int inner_shape = flat_value.dimension(0);
+    int outer_shape = flat_value.dimension(2);
+    auto flat_output = output->shaped<Variant, 2>({inner_shape, outer_shape});
+
+    // Take the first ciphertext in the chip and add all the other chips to it.
+    for (int i = 0; i < flat_output.dimension(0); ++i) {
+      for (int j = 0; j < flat_output.dimension(1); ++j) {
+        // Get the first chip.
+        SymmetricCtVariant<T> const* first_ct_var =
+            std::move(flat_value(i, 0, j).get<SymmetricCtVariant<T>>());
+        OP_REQUIRES(op_ctx, first_ct_var != nullptr,
+                    InvalidArgument(
+                        "SymmetricCtVariant a did not unwrap successfully."));
+        SymmetricCt sum = first_ct_var->ct;  // deep copy to start the sum.
+
+        // Add the remaining chips.
+        for (int chip_dim = 1; chip_dim < dim_sz_to_reduce; ++chip_dim) {
+          SymmetricCtVariant<T> const* ct_var = std::move(
+              flat_value(i, chip_dim, j).get<SymmetricCtVariant<T>>());
+          OP_REQUIRES(op_ctx, ct_var != nullptr,
+                      InvalidArgument(
+                          "SymmetricCtVariant a did not unwrap successfully."));
+          SymmetricCt const& ct = ct_var->ct;
+
+          // Perform the addition.
+          OP_REQUIRES_OK(op_ctx, sum.AddInPlace(ct));
+        }
+
+        // Store in the output.
+        SymmetricCtVariant<T> res_var(std::move(sum));
+        flat_output(i, j) = std::move(res_var);
+      }
+    }
+  }
+};
+
 REGISTER_KERNEL_BUILDER(Name("RotationKeyGen64").Device(DEVICE_CPU),
                         RotationKeyGenOp<uint64>);
 
 REGISTER_KERNEL_BUILDER(Name("Roll64").Device(DEVICE_CPU), RollOp<uint64>);
 
+REGISTER_KERNEL_BUILDER(Name("ReduceSumByRotation64").Device(DEVICE_CPU),
+                        ReduceSumByRotationOp<uint64>);
+
 REGISTER_KERNEL_BUILDER(Name("ReduceSum64").Device(DEVICE_CPU),
                         ReduceSumOp<uint64>);
\ No newline at end of file
diff --git a/shell_tensor/cc/ops/shell_ops.cc b/shell_tensor/cc/ops/shell_ops.cc
index cdc2f90..3272528 100644
--- a/shell_tensor/cc/ops/shell_ops.cc
+++ b/shell_tensor/cc/ops/shell_ops.cc
@@ -207,8 +207,14 @@ REGISTER_OP("Roll64")
     .Output("rotated_value: variant")
     .SetIsStateful();
 
-REGISTER_OP("ReduceSum64")
+REGISTER_OP("ReduceSumByRotation64")
+    .Input("value: variant")
     .Input("rotation_key: variant")
+    .Output("repeated_reduce_sum: variant")
+    .SetIsStateful();
+
+REGISTER_OP("ReduceSum64")
     .Input("value: variant")
+    .Input("axis: int64")
     .Output("repeated_reduce_sum: variant")
     .SetIsStateful();
\ No newline at end of file
diff --git a/shell_tensor/python/ops/shell_ops.py b/shell_tensor/python/ops/shell_ops.py
index 9302b4d..6704243 100644
--- a/shell_tensor/python/ops/shell_ops.py
+++ b/shell_tensor/python/ops/shell_ops.py
@@ -53,4 +53,5 @@
 # Rotate slots.
 rotation_key_gen64 = shell_ops.rotation_key_gen64
 roll64 = shell_ops.roll64
+reduce_sum_by_rotation64 = shell_ops.reduce_sum_by_rotation64
 reduce_sum64 = shell_ops.reduce_sum64
diff --git a/shell_tensor/python/shell_tensor.py b/shell_tensor/python/shell_tensor.py
index 3debeae..7cdd999 100644
--- a/shell_tensor/python/shell_tensor.py
+++ b/shell_tensor/python/shell_tensor.py
@@ -355,6 +355,8 @@ def roll(self, rotation_key, num_slots):
         if not self._is_enc:
             raise ValueError("Unencrypted ShellTensor rotation not supported yet.")
         else:
+            num_slots = tf.cast(num_slots, tf.int64)
+
             return ShellTensor64(
                 value=shell_ops.roll64(rotation_key, self._raw, num_slots),
                 context=self._context,
@@ -365,12 +367,32 @@ def roll(self, rotation_key, num_slots):
                 mult_count=self._mult_count,
             )
 
-    def reduce_sum(self, rotation_key):
+    def reduce_sum(self, axis=0, rotation_key=None):
         if not self._is_enc:
             raise ValueError("Unencrypted ShellTensor reduce_sum not supported yet.")
+        # Check axis is a scalar.
+        elif isinstance(axis, tf.Tensor) and axis.shape != []:
+            raise ValueError("Only scalar `axis` is supported.")
+        elif axis == 0:
+            if rotation_key is None:
+                raise ValueError(
+                    "Rotation key must be provided to reduce_sum over axis 0."
+                )
+
+            return ShellTensor64(
+                value=shell_ops.reduce_sum_by_rotation64(self._raw, rotation_key),
+                context=self._context,
+                num_slots=self._num_slots,
+                underlying_dtype=self._underlying_dtype,
+                is_enc=True,
+                fxp_fractional_bits=self._fxp_fractional_bits,
+                mult_count=self._mult_count,
+            )
         else:
+            if axis >= len(self.shape):
+                raise ValueError("Axis greater than number of dimensions")
             return ShellTensor64(
-                value=shell_ops.reduce_sum64(rotation_key, self._raw),
+                value=shell_ops.reduce_sum64(self._raw, axis),
                 context=self._context,
                 num_slots=self._num_slots,
                 underlying_dtype=self._underlying_dtype,
diff --git a/shell_tensor/test/BUILD b/shell_tensor/test/BUILD
index 03cc908..74e6e4b 100644
--- a/shell_tensor/test/BUILD
+++ b/shell_tensor/test/BUILD
@@ -59,6 +59,7 @@ py_test(
 
 py_test(
     name = "rotation_test",
+    size = "medium",
     srcs = [
         "rotation_test.py",
         "test_utils.py",
diff --git a/shell_tensor/test/rotation_test.py b/shell_tensor/test/rotation_test.py
index e260805..a368514 100644
--- a/shell_tensor/test/rotation_test.py
+++ b/shell_tensor/test/rotation_test.py
@@ -16,22 +16,19 @@
 import tensorflow as tf
 import shell_tensor
 import test_utils
+from multiprocessing import Pool
+from itertools import repeat
 
 
 class TestShellTensorRotation(tf.test.TestCase):
-    # plaintext_dtype = tf.int32
-    # log_slots = 11
-    # slots = 2**log_slots
-
-    # def get_context(self):
-    #     return shell_tensor.create_context64(
-    #         log_n=self.log_slots,
-    #         main_moduli=[8556589057, 8388812801],
-    #         aux_moduli=[34359709697],
-    #         plaintext_modulus=40961,
-    #         noise_variance=8,
-    #         seed="",
-    #     )
+    rotation_dtypes = [
+        tf.int32,
+        tf.int64,
+        tf.float32,
+        tf.float64,
+    ]
+    roll_test_outer_shape = [3, 3]
+    test_outer_shape = [2, 5, 4]
 
     def _test_keygen(self, test_context):
         context = test_context.shell_context
@@ -55,78 +52,167 @@ def plaintext_roll(self, t, shift):
         rotated_tftensor = tf.concat([top, bottom], axis=0)
         return rotated_tftensor
 
-    def _test_rotate(self, test_context, plaintext_dtype):
+    def _test_roll(self, test_context, key, rotation_key, plaintext_dtype, roll_num):
         context = test_context.shell_context
-        key = shell_tensor.create_key64(context)
-        rotation_key = shell_tensor.create_rotation_key64(context, key)
 
-        shift_right = 1
-        shift_left = -1
+        # Create a tensor with the shape of slots x (outer_shape) where each
+        # column of the first dimensions counts from 0 to slots-1.
+        tftensor = tf.range(0, test_context.slots, delta=1, dtype=plaintext_dtype)
+        for i in range(len(self.roll_test_outer_shape)):
+            tftensor = tf.expand_dims(tftensor, axis=-1)
+            tftensor = tf.tile(
+                tftensor, multiples=[1] * (i + 1) + [self.roll_test_outer_shape[i]]
+            )
 
-        tftensor = tf.range(test_context.slots, delta=1, dtype=None, name="range")
-
-        tftensor_right = self.plaintext_roll(tftensor, shift_right)
-        tftensor_left = self.plaintext_roll(tftensor, shift_left)
+        rolled_tftensor = self.plaintext_roll(tftensor, roll_num)
 
         s = shell_tensor.to_shell_tensor(context, tftensor)
         enc = s.get_encrypted(key)
 
-        enc_right = enc.roll(rotation_key, shift_right)
-        tftensor_out = enc_right.get_decrypted(key)
-        self.assertAllClose(tftensor_out, tftensor_right)
-
-        enc_left = enc.roll(rotation_key, shift_left)
-        tftensor_out = enc_left.get_decrypted(key)
-        self.assertAllClose(tftensor_out, tftensor_left)
+        rolled_enc = enc.roll(rotation_key, roll_num)
+        rolled_result = rolled_enc.get_decrypted(key)
+        self.assertAllClose(rolled_tftensor, rolled_result)
 
-    def test_rotate(self):
+    def test_roll(self):
         for test_context in test_utils.test_contexts:
-            for test_dtype in test_utils.test_dtypes:
-                with self.subTest("rotate with dtype %s" % (test_dtype)):
-                    self._test_rotate(test_context, test_dtype)
+            context = test_context.shell_context
+            key = shell_tensor.create_key64(context)
+            rotation_key = shell_tensor.create_rotation_key64(context, key)
+            rotation_range = test_context.slots // 2 - 1
+
+            for test_dtype in self.rotation_dtypes:
+                for roll_num in range(-rotation_range, rotation_range, 1):
+                    with self.subTest(
+                        "rotate with dtype %s, rotating by %s" % (test_dtype, roll_num)
+                    ):
+                        self._test_roll(
+                            test_context, key, rotation_key, test_dtype, roll_num
+                        )
 
     # TensorFlow's reduce_sum has slightly different semantics than encrypted
     # reduce_sum. Encrypted reduce_sum affects top and bottom halves
     # independently, as well as repeating the sum across the halves. This
     # function emulates this in plaintext.
-    def plaintext_reduce_sum(self, t):
+    def plaintext_reduce_sum_axis_0(self, t):
         half_slots = t.shape[0] // 2
-        bottom_answer = tf.math.reduce_sum(
-            t[0 : half_slots], axis=0, keepdims=True
-        )
-        top_answer = tf.math.reduce_sum(t[half_slots :], axis=0, keepdims=True)
+        bottom_answer = tf.math.reduce_sum(t[0:half_slots], axis=0, keepdims=True)
+        top_answer = tf.math.reduce_sum(t[half_slots:], axis=0, keepdims=True)
 
-        repeated_bottom_answer = tf.repeat(
-            bottom_answer, repeats=half_slots, axis=0
-        )
+        repeated_bottom_answer = tf.repeat(bottom_answer, repeats=half_slots, axis=0)
         repeated_top_answer = tf.repeat(top_answer, repeats=half_slots, axis=0)
 
         return tf.concat([repeated_bottom_answer, repeated_top_answer], 0)
 
-    def _test_reduce_sum(self, test_context, plaintext_dtype, frac_bits):
+    def _test_reduce_sum_axis_0(
+        self, test_context, key, rotation_key, plaintext_dtype, frac_bits
+    ):
         context = test_context.shell_context
-        key = shell_tensor.create_key64(context)
-        rotation_key = shell_tensor.create_rotation_key64(context, key)
-        test_shape = [test_context.slots, 2, 3, 4]
 
-        tftensor = tf.random.uniform(test_shape, dtype=tf.int32, maxval=10)
+        # reduce_sum across axis 0 requires adding over all the slots.
+        min_val, max_val = test_utils.get_bounds_for_n_adds(
+            plaintext_dtype,
+            test_context.plaintext_modulus,
+            frac_bits,
+            test_context.slots,
+        )
+
+        if max_val == 0:
+            # Test parameters do not support reduce_sum at this precision.
+            print(
+                "Note: Skipping test reduce_sum_axis0 with dtype %s and frac_bits %d. Not enough precision to support this test."
+                % (plaintext_dtype, frac_bits)
+            )
+            return
+
+        test_shape = self.test_outer_shape.copy()
+        test_shape.insert(0, test_context.slots)
+
+        tftensor = tf.random.uniform(
+            test_shape,
+            dtype=tf.int64,
+            maxval=max_val,
+            minval=min_val,
+        )
+        tftensor = tf.cast(tftensor, plaintext_dtype)
         s = shell_tensor.to_shell_tensor(context, tftensor)
         enc = s.get_encrypted(key)
 
-        enc_reduce_sum = enc.reduce_sum(rotation_key)
+        enc_reduce_sum = enc.reduce_sum(axis=0, rotation_key=rotation_key)
 
         tftensor_out = enc_reduce_sum.get_decrypted(key)
-        self.assertAllClose(tftensor_out, self.plaintext_reduce_sum(tftensor))
+        self.assertAllClose(tftensor_out, self.plaintext_reduce_sum_axis_0(tftensor))
 
-    def test_reduce_sum(self):
+    def test_reduce_sum_axis_0(self):
         for test_context in test_utils.test_contexts:
+            context = test_context.shell_context
+            key = shell_tensor.create_key64(context)
+            rotation_key = shell_tensor.create_rotation_key64(context, key)
+
             for frac_bits in test_utils.test_fxp_fractional_bits:
-                for test_dtype in test_utils.test_dtypes:
+                for test_dtype in self.rotation_dtypes:
                     with self.subTest(
-                        "reduce_sum with fractional bits %d and dtype %s"
+                        "reduce_sum_axis_0 with fractional bits %d and dtype %s"
                         % (frac_bits, test_dtype)
                     ):
-                        self._test_reduce_sum(test_context, test_dtype, frac_bits)
+                        self._test_reduce_sum_axis_0(
+                            test_context, key, rotation_key, test_dtype, frac_bits
+                        )
+
+    def _test_reduce_sum_axis_n(
+        self, test_context, plaintext_dtype, frac_bits, outer_axis
+    ):
+        context = test_context.shell_context
+        key = shell_tensor.create_key64(context)
+
+        # reduce_sum across `axis` requires adding over that dimension.
+        min_val, max_val = test_utils.get_bounds_for_n_adds(
+            plaintext_dtype,
+            test_context.plaintext_modulus,
+            frac_bits,
+            self.test_outer_shape[outer_axis],
+        )
+
+        if max_val == 0:
+            # Test parameters do not support reduce_sum at this precision.
+            print(
+                "Note: Skipping test reduce_sum_axis_n with dtype %s and frac_bits %d. Not enough precision to support this test."
+                % (plaintext_dtype, frac_bits)
+            )
+            return
+
+        test_shape = self.test_outer_shape.copy()
+        test_shape.insert(0, test_context.slots)
+
+        tftensor = tf.random.uniform(
+            test_shape,
+            dtype=tf.int64,
+            maxval=max_val,
+            minval=min_val,
+        )
+        tftensor = tf.cast(tftensor, plaintext_dtype)
+        s = shell_tensor.to_shell_tensor(context, tftensor)
+        enc = s.get_encrypted(key)
+
+        enc_reduce_sum = enc.reduce_sum(axis=outer_axis + 1)
+
+        tftensor_out = enc_reduce_sum.get_decrypted(key)
+        self.assertAllClose(tftensor_out, tf.reduce_sum(tftensor, axis=outer_axis + 1))
+
+    def test_reduce_sum_axis_n(self):
+        for test_context in test_utils.test_contexts:
+            for frac_bits in test_utils.test_fxp_fractional_bits:
+                for test_dtype in self.rotation_dtypes:
+                    for outer_axis in range(len(self.test_outer_shape)):
+                        with self.subTest(
+                            "reduce_sum_axis_n with fractional bits %d, dtype %s, and axis %d"
+                            % (frac_bits, test_dtype, outer_axis + 1)
+                        ):
+                            self._test_reduce_sum_axis_n(
+                                test_context,
+                                test_dtype,
+                                frac_bits,
+                                outer_axis,
+                            )
 
 
 if __name__ == "__main__":
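Usage sketch (illustrative only, not part of the patch): the two reduce_sum paths introduced above can be driven from Python as follows. The context parameters mirror the commented-out example removed from rotation_test.py; the tensor shape, dtype, and values are arbitrary assumptions made for the example.

import tensorflow as tf
import shell_tensor

# Context parameters follow the commented-out block removed from
# rotation_test.py; they are illustrative, not a recommendation.
context = shell_tensor.create_context64(
    log_n=11,
    main_moduli=[8556589057, 8388812801],
    aux_moduli=[34359709697],
    plaintext_modulus=40961,
    noise_variance=8,
    seed="",
)
key = shell_tensor.create_key64(context)
rotation_key = shell_tensor.create_rotation_key64(context, key)

# A packed tensor: axis 0 spans the 2**log_n ciphertext slots, the remaining
# axes are ordinary outer dimensions.
x = tf.ones([2**11, 3, 4], dtype=tf.int32)
enc = shell_tensor.to_shell_tensor(context, x).get_encrypted(key)

# axis=0 sums over the slots via slot rotations (ReduceSumByRotation64), so a
# rotation key is required; each half of the slots is reduced independently and
# the result is repeated across that half.
sum_axis0 = enc.reduce_sum(axis=0, rotation_key=rotation_key).get_decrypted(key)

# axis>=1 sums over an outer packing dimension (ReduceSum64); no rotation key
# is needed for this path.
sum_axis1 = enc.reduce_sum(axis=1).get_decrypted(key)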