Skip to content

Commit

Permalink
Implement ciphertext reduce_sum for anything other than the first axis.
Browse files Browse the repository at this point in the history
First axis was done in a previous commit.
  • Loading branch information
james-choncholas committed Feb 2, 2024
1 parent 98f9d6b commit 6e4be58
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 74 deletions.
142 changes: 123 additions & 19 deletions shell_tensor/cc/kernels/rotation_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ using tensorflow::uint8;
using tensorflow::Variant;
using tensorflow::errors::InvalidArgument;

constexpr int kLogGadgetBase = 10;
constexpr int kLogGadgetBase = 4;
constexpr rlwe::PrngType kPrngType = rlwe::PRNG_TYPE_HKDF;

template <typename T>
Expand Down Expand Up @@ -132,6 +132,9 @@ class RollOp : public OpKernel {
void Compute(OpKernelContext* op_ctx) override {
OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, op_ctx,
GetVariant<RotationKeyVariant<T>>(op_ctx, 0));
OP_REQUIRES(
op_ctx, rotation_key_var != nullptr,
InvalidArgument("RotationKeyVariant a did not unwrap successfully."));
std::map<int, PowerAndKey> const* keys = &rotation_key_var->keys;

Tensor const& value = op_ctx->input(1);
Expand Down Expand Up @@ -167,9 +170,12 @@ class RollOp : public OpKernel {
shift += num_slots / 2;
}

OP_REQUIRES(op_ctx, keys->find(shift) != keys->end(),
InvalidArgument("No key for shift of '", shift, "'"));
PowerAndKey const& p_and_k = keys->at(shift);
PowerAndKey const* p_and_k;
if (shift != 0) {
OP_REQUIRES(op_ctx, keys->find(shift) != keys->end(),
InvalidArgument("No key for shift of '", shift, "'"));
p_and_k = &keys->at(shift);
}

for (int i = 0; i < flat_output.dimension(0); ++i) {
SymmetricCtVariant<T> const* ct_var =
Expand All @@ -184,8 +190,8 @@ class RollOp : public OpKernel {
flat_output(i) = std::move(ct_out_var);
} else {
OP_REQUIRES_VALUE(auto ct_sub, op_ctx,
ct.Substitute(p_and_k.substitution_power));
OP_REQUIRES_VALUE(auto ct_rot, op_ctx, p_and_k.key.ApplyTo(ct_sub));
ct.Substitute(p_and_k->substitution_power));
OP_REQUIRES_VALUE(auto ct_rot, op_ctx, p_and_k->key.ApplyTo(ct_sub));

SymmetricCtVariant ct_out_var(std::move(ct_rot));
flat_output(i) = std::move(ct_out_var);
Expand All @@ -195,36 +201,40 @@ class RollOp : public OpKernel {
};

template <typename T>
class ReduceSumOp : public OpKernel {
class ReduceSumByRotationOp : public OpKernel {
private:
using ModularInt = rlwe::MontgomeryInt<T>;
using RotationKey = rlwe::RnsGaloisKey<ModularInt>;
using SymmetricCt = rlwe::RnsBgvCiphertext<ModularInt>;
using PowerAndKey = typename RotationKeyVariant<T>::PowerAndKey;

public:
explicit ReduceSumOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {}
explicit ReduceSumByRotationOp(OpKernelConstruction* op_ctx)
: OpKernel(op_ctx) {}

void Compute(OpKernelContext* op_ctx) override {
// Recover the input rotation keys.
OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, op_ctx,
GetVariant<RotationKeyVariant<T>>(op_ctx, 0));
std::map<int, PowerAndKey> const* keys = &rotation_key_var->keys;

// Recover the input tensor.
Tensor const& value = op_ctx->input(1);
Tensor const& value = op_ctx->input(0);
OP_REQUIRES(op_ctx, value.dim_size(0) > 0,
InvalidArgument("Cannot reduce_sum an empty ciphertext."));

auto flat_value = value.flat<Variant>();

// Setup the output.
// Recover the input rotation keys.
OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, op_ctx,
GetVariant<RotationKeyVariant<T>>(op_ctx, 1));
OP_REQUIRES(
op_ctx, rotation_key_var != nullptr,
InvalidArgument("RotationKeyVariant a did not unwrap successfully."));
std::map<int, PowerAndKey> const* keys = &rotation_key_var->keys;

Tensor* output;
OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, value.shape(), &output));
auto flat_output = output->flat<Variant>();

for (int i = 0; i < flat_output.dimension(0); ++i) {
// Learn how many slots there are from first ciphertext and create a deep
// copy to hold the sum.
// Learn how many slots there are from first ciphertext and create a
// deep copy to hold the sum.
SymmetricCtVariant<T> const* ct_var =
std::move(flat_value(i).get<SymmetricCtVariant<T>>());
OP_REQUIRES(
Expand All @@ -235,7 +245,7 @@ class ReduceSumOp : public OpKernel {

// Add the rotations to the sum.
// Note the ciphertext rotations operate on each half of the ciphertext
// separately. So the max rotatation is by half the number of slots.
// separately. So the max rotation is by half the number of slots.
for (int shift = 1; shift < num_slots / 2; shift <<= 1) {
// TODO if debug
OP_REQUIRES(op_ctx, keys->find(shift) != keys->end(),
Expand All @@ -248,7 +258,7 @@ class ReduceSumOp : public OpKernel {
OP_REQUIRES_VALUE(auto ct_rot, op_ctx, p_and_k.key.ApplyTo(ct_sub));

// Add to the sum.
sum.AddInPlace(ct_rot);
OP_REQUIRES_OK(op_ctx, sum.AddInPlace(ct_rot));
}

SymmetricCtVariant ct_out_var(std::move(sum));
Expand All @@ -257,10 +267,104 @@ class ReduceSumOp : public OpKernel {
}
};

template <typename T>
class ReduceSumOp : public OpKernel {
private:
using ModularInt = rlwe::MontgomeryInt<T>;
using RotationKey = rlwe::RnsGaloisKey<ModularInt>;
using SymmetricCt = rlwe::RnsBgvCiphertext<ModularInt>;
using PowerAndKey = typename RotationKeyVariant<T>::PowerAndKey;

public:
explicit ReduceSumOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {}

void Compute(OpKernelContext* op_ctx) override {
// Recover the input tensor.
Tensor const& value = op_ctx->input(0);
OP_REQUIRES(op_ctx, value.dim_size(0) > 0,
InvalidArgument("Cannot reduce_sum an empty ciphertext."));

// Recover the axis to reduce over.
Tensor const& axis_tensor = op_ctx->input(1);
OP_REQUIRES(op_ctx, axis_tensor.NumElements() == 1,
InvalidArgument("axis must be scalar, saw shape: ",
axis_tensor.shape().DebugString()));
OP_REQUIRES_VALUE(int64 axis, op_ctx, GetScalar<int64>(op_ctx, 1));

// The axis to reduce over.
int dim_to_reduce = axis - 1;

// Check axis is within dim size.
OP_REQUIRES(op_ctx, dim_to_reduce < value.dims(),
InvalidArgument("Cannot reduce_sum over polynomial_axis '",
dim_to_reduce, "' (axis '", axis,
"') for input with shape ",
value.shape().DebugString()));

uint8_t dim_sz_to_reduce = value.dim_size(dim_to_reduce);

// Create a temp Tensor to hold intermediate sums during the reduction.
// It is the same size as the input Tensor.
Tensor intermediate_sums;
auto intermediate_sums_shape = value.shape();
OP_REQUIRES_OK(op_ctx, op_ctx->allocate_temp(tensorflow::DT_VARIANT,
intermediate_sums_shape,
&intermediate_sums));

auto flat_intermediate_sums =
intermediate_sums.flat_inner_outer_dims<Variant>(dim_to_reduce - 1);
auto flat_value = value.flat_inner_outer_dims<Variant>(dim_to_reduce - 1);

// Setup the output.
Tensor* output;
auto output_shape = value.shape();
OP_REQUIRES_OK(op_ctx, output_shape.RemoveDimWithStatus(dim_to_reduce));
OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output));
// Setup a shape to access the output Tensor as a flat Tensor, with the
// same indexing as the input Tensor excluding the dimension to reduce.
int inner_shape = flat_value.dimension(0);
int outer_shape = flat_value.dimension(2);
auto flat_output = output->shaped<Variant, 2>({inner_shape, outer_shape});

// Take the first ciphertext in the chip and add all the other chips to it.
for (int i = 0; i < flat_output.dimension(0); ++i) {
for (int j = 0; j < flat_output.dimension(1); ++j) {
// Get the first chip.
SymmetricCtVariant<T> const* first_ct_var =
std::move(flat_value(i, 0, j).get<SymmetricCtVariant<T>>());
OP_REQUIRES(op_ctx, first_ct_var != nullptr,
InvalidArgument(
"SymmetricCtVariant a did not unwrap successfully."));
SymmetricCt sum = first_ct_var->ct; // deep copy to start the sum.

// Add the remaining chips.
for (int chip_dim = 1; chip_dim < dim_sz_to_reduce; ++chip_dim) {
SymmetricCtVariant<T> const* ct_var = std::move(
flat_value(i, chip_dim, j).get<SymmetricCtVariant<T>>());
OP_REQUIRES(op_ctx, ct_var != nullptr,
InvalidArgument(
"SymmetricCtVariant a did not unwrap successfully."));
SymmetricCt const& ct = ct_var->ct;

// Perform the addition.
OP_REQUIRES_OK(op_ctx, sum.AddInPlace(ct));
}

// Store in the output.
SymmetricCtVariant res_var(std::move(sum));
flat_output(i, j) = std::move(res_var);
}
}
}
};

REGISTER_KERNEL_BUILDER(Name("RotationKeyGen64").Device(DEVICE_CPU),
RotationKeyGenOp<uint64>);

REGISTER_KERNEL_BUILDER(Name("Roll64").Device(DEVICE_CPU), RollOp<uint64>);

REGISTER_KERNEL_BUILDER(Name("ReduceSumByRotation64").Device(DEVICE_CPU),
ReduceSumByRotationOp<uint64>);

REGISTER_KERNEL_BUILDER(Name("ReduceSum64").Device(DEVICE_CPU),
ReduceSumOp<uint64>);
8 changes: 7 additions & 1 deletion shell_tensor/cc/ops/shell_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,14 @@ REGISTER_OP("Roll64")
.Output("rotated_value: variant")
.SetIsStateful();

REGISTER_OP("ReduceSum64")
REGISTER_OP("ReduceSumByRotation64")
.Input("value: variant")
.Input("rotation_key: variant")
.Output("repeated_reduce_sum: variant")
.SetIsStateful();

// Sums a variant tensor of ciphertexts over the scalar int64 `axis` input.
// NOTE(review): the kernel computes `dim_to_reduce = axis - 1`, so `axis`
// appears to be 1-based with axis 0 (the slot dimension) handled by
// ReduceSumByRotation64 — confirm against callers.
REGISTER_OP("ReduceSum64")
    .Input("value: variant")
    .Input("axis: int64")
    .Output("repeated_reduce_sum: variant")
    .SetIsStateful();
1 change: 1 addition & 0 deletions shell_tensor/python/ops/shell_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@
# Rotate slots.
rotation_key_gen64 = shell_ops.rotation_key_gen64
roll64 = shell_ops.roll64
# Reduce-sum over the slot dimension; takes (value, rotation_key).
reduce_sum_by_rotation64 = shell_ops.reduce_sum_by_rotation64
# Reduce-sum over a non-slot axis; takes (value, axis).
reduce_sum64 = shell_ops.reduce_sum64
26 changes: 24 additions & 2 deletions shell_tensor/python/shell_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ def roll(self, rotation_key, num_slots):
if not self._is_enc:
raise ValueError("Unencrypted ShellTensor rotation not supported yet.")
else:
num_slots = tf.cast(num_slots, tf.int64)

return ShellTensor64(
value=shell_ops.roll64(rotation_key, self._raw, num_slots),
context=self._context,
Expand All @@ -365,12 +367,32 @@ def roll(self, rotation_key, num_slots):
mult_count=self._mult_count,
)

def reduce_sum(self, rotation_key):
def reduce_sum(self, axis=0, rotation_key=None):
if not self._is_enc:
raise ValueError("Unencrypted ShellTensor reduce_sum not supported yet.")
# Check axis is a scalar
elif isinstance(axis, tf.Tensor) and not axis.shape != []:
raise ValueError("Only scalar `axis` is supported.")
elif axis == 0:
if rotation_key is None:
raise ValueError(
"Rotation key must be provided to reduce_sum over axis 0."
)

return ShellTensor64(
value=shell_ops.reduce_sum_by_rotation64(self._raw, rotation_key),
context=self._context,
num_slots=self._num_slots,
underlying_dtype=self._underlying_dtype,
is_enc=True,
fxp_fractional_bits=self._fxp_fractional_bits,
mult_count=self._mult_count,
)
else:
if axis >= len(self.shape):
raise ValueError("Axis greater than number of dimensions")
return ShellTensor64(
value=shell_ops.reduce_sum64(rotation_key, self._raw),
value=shell_ops.reduce_sum64(self._raw, axis),
context=self._context,
num_slots=self._num_slots,
underlying_dtype=self._underlying_dtype,
Expand Down
1 change: 1 addition & 0 deletions shell_tensor/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ py_test(

py_test(
name = "rotation_test",
size = "medium",
srcs = [
"rotation_test.py",
"test_utils.py",
Expand Down
Loading

0 comments on commit 6e4be58

Please sign in to comment.