Skip to content

Commit

Permalink
Reduce sum is avoided in dpsgd training (and ops like matmul and segm…
Browse files Browse the repository at this point in the history
…ent sum).
  • Loading branch information
james-choncholas committed Oct 26, 2024
1 parent 55508ed commit e9d7a83
Show file tree
Hide file tree
Showing 20 changed files with 348 additions and 279 deletions.
108 changes: 46 additions & 62 deletions tf_shell/cc/kernels/mul_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ class MatMulCtPtOp : public OpKernel {
}
};

template <typename PtT, typename T, bool use_fast_rotations>
template <typename PtT, typename T>
class MatMulPtCtOp : public OpKernel {
using ModularInt = rlwe::MontgomeryInt<T>;
using Context = rlwe::RnsContext<ModularInt>;
Expand All @@ -557,8 +557,20 @@ class MatMulPtCtOp : public OpKernel {
using Encoder = rlwe::FiniteFieldEncoder<ModularInt>;
using RotationKey = rlwe::RnsGaloisKey<ModularInt>;

std::string reduction;
char const* galois_reduction = "galois";
char const* fast_reduction = "fast";
char const* no_reduction = "none";

public:
explicit MatMulPtCtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {}
explicit MatMulPtCtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {
OP_REQUIRES_OK(op_ctx, op_ctx->GetAttr("reduction", &reduction));
OP_REQUIRES(
op_ctx,
reduction == "galois" || reduction == "fast" || reduction == "none",
InvalidArgument("Invalid reduction attribute: ", reduction,
". Must be 'galois', 'fast', or 'none'."));
}

void Compute(OpKernelContext* op_ctx) override {
// Get the input tensors.
Expand All @@ -579,9 +591,7 @@ class MatMulPtCtOp : public OpKernel {

// Rotation keys are only required when the reduction uses Galois key
// rotations; they are not needed when fast rotations are used.
RotationKeyVariant<T> const* rotation_key_var = nullptr;
if constexpr (use_fast_rotations) {
(void)rotation_key_var;
} else {
if (reduction == galois_reduction) {
OP_REQUIRES_VALUE(rotation_key_var, op_ctx,
GetVariant<RotationKeyVariant<T>>(op_ctx, 3));
OP_REQUIRES(
Expand All @@ -590,10 +600,12 @@ class MatMulPtCtOp : public OpKernel {
OP_REQUIRES_OK(op_ctx,
const_cast<RotationKeyVariant<T>*>(rotation_key_var)
->MaybeLazyDecode(shell_ctx_var->ct_context_));
} else {
(void)rotation_key_var;
}
std::vector<std::shared_ptr<RotationKey>> empty_rot_keys{};
std::vector<std::shared_ptr<RotationKey>> const& rot_keys =
use_fast_rotations ? empty_rot_keys : rotation_key_var->keys;
reduction == galois_reduction ? rotation_key_var->keys : empty_rot_keys;

// b is a vector of Polynomials so first dimension is the number of
// slots.
Expand Down Expand Up @@ -729,7 +741,25 @@ class MatMulPtCtOp : public OpKernel {
// Note the ciphertext rotations operate on each half of the
// ciphertext separately. So the max rotation is by half the
// number of slots.
if constexpr (use_fast_rotations) {
if (reduction == galois_reduction) {
for (int shift = 1; shift < num_slots / 2; shift <<= 1) {
OP_REQUIRES(
op_ctx,
shift - 1 <
static_cast<int>(rot_keys.size()), // Skip key 0.
InvalidArgument("No key for shift of '", shift, "'"));
RotationKey const* k = rot_keys[shift].get();

// Rotate by the shift.
OP_REQUIRES_VALUE(auto ct_sub, op_ctx,
ct_result.Substitute(k->SubstitutionPower()));
OP_REQUIRES_VALUE(auto ct_rot, op_ctx, k->ApplyTo(ct_sub));

// Add to the sum.
OP_REQUIRES_OK(op_ctx, ct_result.AddInPlace(ct_rot));
}

} else if (reduction == fast_reduction) {
OP_REQUIRES_VALUE(
RnsPolynomial sum_component_zero, op_ctx,
ct_result.Component(0)); // deep copy to start the sum.
Expand Down Expand Up @@ -759,25 +789,8 @@ class MatMulPtCtOp : public OpKernel {
ct_result.PowerOfS(),
ct_result.Error() * ct_result.LogN(),
ct_result.ErrorParams()};

} else {
for (int shift = 1; shift < num_slots / 2; shift <<= 1) {
OP_REQUIRES(
op_ctx,
shift - 1 <
static_cast<int>(rot_keys.size()), // Skip key 0.
InvalidArgument("No key for shift of '", shift, "'"));
RotationKey const* k = rot_keys[shift].get();

// Rotate by the shift.
OP_REQUIRES_VALUE(auto ct_sub, op_ctx,
ct_result.Substitute(k->SubstitutionPower()));
OP_REQUIRES_VALUE(auto ct_rot, op_ctx, k->ApplyTo(ct_sub));

// Add to the sum.
OP_REQUIRES_OK(op_ctx, ct_result.AddInPlace(ct_rot));
}
}
// else if no reduction, do nothing.

// At this point we have one ciphertext per row of the plaintext
// matrix where every element in the ciphertext is the same value,
Expand Down Expand Up @@ -896,57 +909,28 @@ REGISTER_KERNEL_BUILDER(
// Matrix multiply plaintext and ciphertext.
REGISTER_KERNEL_BUILDER(
Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<uint8>("Dtype"),
MatMulPtCtOp<uint8, uint64, false>);
MatMulPtCtOp<uint8, uint64>);
REGISTER_KERNEL_BUILDER(
Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<int8>("Dtype"),
MatMulPtCtOp<int8, uint64, false>);
MatMulPtCtOp<int8, uint64>);

REGISTER_KERNEL_BUILDER(
Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<uint16>("Dtype"),
MatMulPtCtOp<uint16, uint64, false>);
MatMulPtCtOp<uint16, uint64>);
REGISTER_KERNEL_BUILDER(
Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<int16>("Dtype"),
MatMulPtCtOp<int16, uint64, false>);
MatMulPtCtOp<int16, uint64>);

REGISTER_KERNEL_BUILDER(
Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<uint32>("Dtype"),
MatMulPtCtOp<uint32, uint64, false>);
MatMulPtCtOp<uint32, uint64>);
REGISTER_KERNEL_BUILDER(
Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<int32>("Dtype"),
MatMulPtCtOp<int32, uint64, false>);
MatMulPtCtOp<int32, uint64>);

REGISTER_KERNEL_BUILDER(
Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<uint64>("Dtype"),
MatMulPtCtOp<uint64, uint64, false>);
MatMulPtCtOp<uint64, uint64>);
REGISTER_KERNEL_BUILDER(
Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<int64>("Dtype"),
MatMulPtCtOp<int64, uint64, false>);

// Matrix multiply plaintext and ciphertext with fast rotations.
REGISTER_KERNEL_BUILDER(
Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<uint8>("Dtype"),
MatMulPtCtOp<uint8, uint64, true>);
REGISTER_KERNEL_BUILDER(
Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<int8>("Dtype"),
MatMulPtCtOp<int8, uint64, true>);

REGISTER_KERNEL_BUILDER(
Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<uint16>("Dtype"),
MatMulPtCtOp<uint16, uint64, true>);
REGISTER_KERNEL_BUILDER(
Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<int16>("Dtype"),
MatMulPtCtOp<int16, uint64, true>);

REGISTER_KERNEL_BUILDER(
Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<uint32>("Dtype"),
MatMulPtCtOp<uint32, uint64, true>);
REGISTER_KERNEL_BUILDER(
Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<int32>("Dtype"),
MatMulPtCtOp<int32, uint64, true>);

REGISTER_KERNEL_BUILDER(
Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<uint64>("Dtype"),
MatMulPtCtOp<uint64, uint64, true>);
REGISTER_KERNEL_BUILDER(
Name("FastMatMulPtCt64").Device(DEVICE_CPU).TypeConstraint<int64>("Dtype"),
MatMulPtCtOp<int64, uint64, true>);
MatMulPtCtOp<int64, uint64>);
Loading

0 comments on commit e9d7a83

Please sign in to comment.