Commit

Rotation key encoding for networking and key caching.

james-choncholas committed Oct 11, 2024
1 parent e971b9f commit 5b39de4

Showing 7 changed files with 154 additions and 16 deletions.
tf_shell/cc/kernels/context_variant.h (2 changes: 1 addition & 1 deletion)
@@ -124,7 +124,7 @@ class ContextVariant {
substitution_powers_.reserve(num_slots / 2);
for (uint shift = 0; shift < num_slots / 2; ++shift) {
substitution_powers_.push_back(sub_power);
sub_power *= base_power;
sub_power *= kSubstitutionBasePower;
sub_power %= two_n;
}

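A note on the loop above: it fills substitution_powers_ with successive powers of kSubstitutionBasePower modulo 2N, i.e. the Galois element 5^k mod 2N that rotates the packed slots by k positions. The standalone snippet below only illustrates that sequence for a toy ring with 2N = 16; it assumes sub_power starts at 1 (the excerpt does not show its initialization), and none of its values come from the commit.

#include <cstdio>

int main() {
  const unsigned two_n = 16;  // toy value of 2N, for illustration only
  unsigned sub_power = 1;     // assumed starting power (rotation by 0)
  for (unsigned shift = 0; shift < 4; ++shift) {
    std::printf("shift %u -> substitution power %u\n", shift, sub_power);
    // Mirrors: sub_power *= kSubstitutionBasePower; sub_power %= two_n;
    sub_power = (sub_power * 5) % two_n;
  }
  // Prints the powers 1, 5, 9, 13; the next step wraps back to 1.
  return 0;
}

After N/2 steps the powers of 5 cycle back to 1, which is consistent with the loop above storing only num_slots / 2 entries.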
tf_shell/cc/kernels/mul_kernels.cc (6 changes: 6 additions & 0 deletions)
@@ -581,6 +581,12 @@ class MatMulPtCtOp : public OpKernel {
} else {
OP_REQUIRES_VALUE(rotation_key_var, op_ctx,
GetVariant<RotationKeyVariant<T>>(op_ctx, 3));
OP_REQUIRES(
op_ctx, rotation_key_var != nullptr,
InvalidArgument("RotationKeyVariant did not unwrap successfully."));
OP_REQUIRES_OK(op_ctx,
const_cast<RotationKeyVariant<T>*>(rotation_key_var)
->MaybeLazyDecode(shell_ctx_var->ct_context_));
}
std::vector<std::shared_ptr<RotationKey>> empty_rot_keys{};
std::vector<std::shared_ptr<RotationKey>> const& rot_keys =
tf_shell/cc/kernels/rotation_kernels.cc (23 changes: 17 additions & 6 deletions)
@@ -44,7 +44,6 @@ using tensorflow::uint8;
using tensorflow::Variant;
using tensorflow::errors::InvalidArgument;

constexpr int kLogGadgetBase = 4;
constexpr rlwe::PrngType kPrngType = rlwe::PRNG_TYPE_HKDF;

template <typename T>
@@ -85,6 +84,10 @@ class RotationKeyGenOp : public OpKernel {
// Create the output variant
RotationKeyVariant<T> v_out;

// The RotationKeys internally hold a pointer to the secret key's moduli.
// Thus the context pointer must also come from the secret key.
v_out.ct_context = secret_key_var->ct_context;

// Create the gadget.
int level = shell_ctx->NumMainPrimeModuli() - 1;
OP_REQUIRES_VALUE(auto q_hats, op_ctx,
@@ -114,9 +117,9 @@
// Skip rotation key at zero, it does not rotate.
if (start == 0) ++start;

uint sub_power = base_power;
uint sub_power = kSubstitutionBasePower;
for (int i = 1; i < start; ++i) {
sub_power *= base_power;
sub_power *= kSubstitutionBasePower;
sub_power %= two_n;
}

@@ -125,8 +128,8 @@
RotationKey k, op_ctx,
RotationKey::CreateForBgv(*secret_key, sub_power, variance,
gadget_ptr.get(), t, kPrngType));
v_out.keys[i] = std::move(std::make_shared<RotationKey>(k));
sub_power *= base_power;
v_out.keys[i] = std::make_shared<RotationKey>(std::move(k));
sub_power *= kSubstitutionBasePower;
sub_power %= two_n;
}
};
@@ -163,6 +166,8 @@ class RollOp : public OpKernel {
OP_REQUIRES(
op_ctx, rotation_key_var != nullptr,
InvalidArgument("RotationKeyVariant did not unwrap successfully."));
OP_REQUIRES_OK(op_ctx, const_cast<RotationKeyVariant<T>*>(rotation_key_var)
->MaybeLazyDecode(shell_ctx_var->ct_context_));
std::vector<std::shared_ptr<RotationKey>> const& keys =
rotation_key_var->keys;

@@ -282,6 +287,8 @@ class ReduceSumByRotationCtOp : public OpKernel {
OP_REQUIRES(
op_ctx, rotation_key_var != nullptr,
InvalidArgument("RotationKeyVariant did not unwrap successfully."));
OP_REQUIRES_OK(op_ctx, const_cast<RotationKeyVariant<T>*>(rotation_key_var)
->MaybeLazyDecode(shell_ctx_var->ct_context_));
std::vector<std::shared_ptr<RotationKey>> const& keys =
rotation_key_var->keys;

@@ -681,4 +688,8 @@ REGISTER_KERNEL_BUILDER(Name("ReduceSumWithModulusPt")
ReduceSumWithModulusPtOp<int8, uint64>);

REGISTER_KERNEL_BUILDER(Name("ReduceSumCt64").Device(DEVICE_CPU),
ReduceSumCtOp<uint64>);
ReduceSumCtOp<uint64>);

typedef RotationKeyVariant<uint64> RotationKeyVariantUint64;
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RotationKeyVariantUint64,
RotationKeyVariantUint64::kTypeName);
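The REGISTER_UNARY_VARIANT_DECODE_FUNCTION registration above is what lets TensorFlow call RotationKeyVariant::Decode when a serialized variant tensor arrives, e.g. after being shipped to another process. For readers unfamiliar with the mechanism, here is a minimal sketch of the surface such a type needs. ToyVariant and its payload field are hypothetical; only the macro and the VariantTensorData usage mirror what the commit does.

#include <string>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"

// Hypothetical variant type, shown only to illustrate the Encode/Decode
// surface that the registration macro hooks into.
struct ToyVariant {
  static constexpr char const kTypeName[] = "ToyVariant";
  std::string TypeName() const { return kTypeName; }

  // Serialize state into string tensors, as RotationKeyVariant::Encode does.
  void Encode(tensorflow::VariantTensorData* data) const {
    data->tensors_.push_back(tensorflow::Tensor(payload));
  }

  // Rebuild state from the string tensors; TensorFlow invokes this through
  // the registered decode function.
  bool Decode(tensorflow::VariantTensorData const& data) {
    if (data.tensors_.size() != 1) return false;
    auto const& str = data.tensors_[0].scalar<tensorflow::tstring>()();
    payload = std::string(str.begin(), str.end());
    return true;
  }

  std::string DebugString() const { return "ToyVariant"; }

  std::string payload;
};

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(ToyVariant, ToyVariant::kTypeName);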
tf_shell/cc/kernels/rotation_kernels_fast.cc (7 changes: 3 additions & 4 deletions)
@@ -45,7 +45,6 @@ using tensorflow::uint8;
using tensorflow::Variant;
using tensorflow::errors::InvalidArgument;

constexpr int kLogGadgetBase = 4;
constexpr rlwe::PrngType kPrngType = rlwe::PRNG_TYPE_HKDF;

// Fast Rotation Kernels:
@@ -118,9 +117,9 @@ class FastRotationKeyGenOp : public OpKernel {
keys.push_back(key_sub_i);

for (uint i = 1; i < num_slots / 2; ++i) {
OP_REQUIRES_VALUE(
key_sub_i, op_ctx,
key_sub_i.Substitute(base_power, shell_ctx->MainPrimeModuli()));
OP_REQUIRES_VALUE(key_sub_i, op_ctx,
key_sub_i.Substitute(kSubstitutionBasePower,
shell_ctx->MainPrimeModuli()));
keys.push_back(key_sub_i);
}

tf_shell/cc/kernels/rotation_variants.h (122 changes: 118 additions & 4 deletions)
@@ -21,12 +21,14 @@
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "utils.h"

using tensorflow::VariantTensorData;

template <typename T>
class RotationKeyVariant {
using ModularInt = rlwe::MontgomeryInt<T>;
using Context = rlwe::RnsContext<ModularInt>;
using Gadget = rlwe::RnsGadget<ModularInt>;
using RotationKey = rlwe::RnsGaloisKey<ModularInt>;

@@ -37,18 +39,130 @@ class RotationKeyVariant {

std::string TypeName() const { return kTypeName; }

// TODO(jchoncholas): implement for networking
void Encode(VariantTensorData* data) const {};
void Encode(VariantTensorData* data) const {
auto async_key_strs = key_strs; // Make sure the key strings are not deallocated mid-encode.
auto async_ct_context = ct_context;

if (async_ct_context == nullptr) {
// If the context is null, this may have been decoded but not lazy decoded
// yet. In this case, directly encode the key strings.
if (async_key_strs == nullptr) {
std::cout << "ERROR: Rotation key not set, cannot encode." << std::endl;
return;
}
data->tensors_.reserve(async_key_strs->size());
for (auto const& key_str : *async_key_strs) {
data->tensors_.push_back(Tensor(key_str));
}
return;
}

// Skip first rotation key at index 0.
data->tensors_.reserve(keys.size() - 1);

for (int i = 1; i < keys.size(); i++) {
auto serialized_key_or = keys[i]->Serialize();
if (!serialized_key_or.ok()) {
std::cout << "ERROR: Failed to serialize rotation key: "
<< serialized_key_or.status();
return;
}
std::string serialized_key;
serialized_key_or.value().SerializeToString(&serialized_key);
data->tensors_.push_back(Tensor(serialized_key));
}
};

bool Decode(VariantTensorData const& data) {
if (data.tensors_.size() < 1) {
std::cout << "ERROR: Not enough tensors to deserialize rotation key."
<< std::endl;
return false;
}

if (key_strs != nullptr) {
std::cout << "ERROR: Rotation key already decoded." << std::endl;
return false;
}

// TODO(jchoncholas): implement for networking
bool Decode(VariantTensorData const& data) { return false; };
size_t num_keys = data.tensors_.size();
std::vector<std::string> building_key_strs;
building_key_strs.reserve(num_keys);

for (size_t i = 0; i < num_keys; ++i) {
std::string const serialized_key(
data.tensors_[i].scalar<tstring>()().begin(),
data.tensors_[i].scalar<tstring>()().end());

building_key_strs.push_back(std::move(serialized_key));
}

key_strs = std::make_shared<std::vector<std::string>>(
std::move(building_key_strs));

return true;
};

Status MaybeLazyDecode(std::shared_ptr<Context const> ct_context_) {
std::lock_guard<std::mutex> lock(mutex.mutex);

// If the keys have already been fully decoded, nothing to do.
if (ct_context != nullptr) {
return OkStatus();
}

// Re-create the gadget.
int level = ct_context_->NumMainPrimeModuli() - 1;
TF_ASSIGN_OR_RETURN(auto q_hats,
ct_context_->MainPrimeModulusComplements(level));
TF_ASSIGN_OR_RETURN(auto q_hat_invs,
ct_context_->MainPrimeModulusCrtFactors(level));
std::vector<size_t> log_bs(ct_context_->NumMainPrimeModuli(),
kLogGadgetBase);
TF_ASSIGN_OR_RETURN(
auto raw_gadget,
Gadget::Create(ct_context_->LogN(), log_bs, q_hats, q_hat_invs,
ct_context_->MainPrimeModuli()));
gadget = std::make_shared<Gadget>(std::move(raw_gadget));

// Decode the keys.
// Index 0 corresponds to a rotation by 0, which needs no key; a placeholder
// keeps key indices aligned with rotation amounts.
keys.reserve(key_strs->size() + 1);
keys.push_back(nullptr);

for (auto const& key_str : *key_strs) {
rlwe::SerializedRnsGaloisKey serialized_key;
bool ok = serialized_key.ParseFromString(key_str);
if (!ok) {
return InvalidArgument("Failed to parse rotation key.");
}

// Using the moduli, reconstruct the key polynomial.
TF_ASSIGN_OR_RETURN(RotationKey key, RotationKey::Deserialize(
serialized_key, gadget.get(),
ct_context_->MainPrimeModuli()));

keys.push_back(std::make_shared<RotationKey>(std::move(key)));
}

// Hold a pointer to the context for future encoding.
ct_context = ct_context_;

// Clear the key strings.
key_strs = nullptr;

return OkStatus();
};

std::string DebugString() const { return "ShellRotationKeyVariant"; }

variant_mutex mutex;
// Each key holds a raw pointer to the gadget, so keep the gadget behind a
// shared_ptr so it stays alive across copies of this variant.
std::shared_ptr<Gadget> gadget;
// Rotation keys do not have default constructors, so use a shared pointer.
std::vector<std::shared_ptr<RotationKey>> keys;
std::shared_ptr<std::vector<std::string>> key_strs;
std::shared_ptr<Context const> ct_context;
};

template <typename T>
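To make the intended flow concrete, here is a rough usage sketch of the new methods. It assumes tf-shell's headers are available, a RotationKeyVariant<tensorflow::uint64> named sender that RotationKeyGenOp has populated (keys and ct_context set), and a shell RnsContext shared pointer named ct_context on the receiving side; the helper name and everything else below are illustrative and not part of the commit.

#include <memory>

#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tf_shell/cc/kernels/rotation_variants.h"

// Hypothetical helper illustrating encode -> decode -> lazy decode.
tensorflow::Status RoundTripSketch(
    RotationKeyVariant<tensorflow::uint64> const& sender,
    std::shared_ptr<
        rlwe::RnsContext<rlwe::MontgomeryInt<tensorflow::uint64>> const>
        ct_context) {
  // Sender side: each Galois key is serialized into a string tensor.
  tensorflow::VariantTensorData wire;
  sender.Encode(&wire);

  // Receiver side: Decode only stores the serialized key strings (cheap).
  RotationKeyVariant<tensorflow::uint64> receiver;
  if (!receiver.Decode(wire)) {
    return tensorflow::errors::InvalidArgument("Decode failed.");
  }

  // First use: rebuild the gadget and the Galois keys from the cached
  // strings, using the receiving side's moduli.
  TF_RETURN_IF_ERROR(receiver.MaybeLazyDecode(ct_context));
  return tensorflow::OkStatus();
}

The split keeps Decode cheap at tensor (de)serialization time, while the expensive key reconstruction happens once, under the variant's mutex, when a kernel first needs the keys; this matches the key caching mentioned in the commit message.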
tf_shell/cc/kernels/segment_kernels.cc (5 changes: 5 additions & 0 deletions)
@@ -430,6 +430,11 @@ class UnsortedSegmentReductionOp : public OpKernel {
Tensor const& num_segments = context->input(3);
OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, context,
GetVariant<RotationKeyVariant<T>>(context, 4));
OP_REQUIRES(
context, rotation_key_var != nullptr,
InvalidArgument("RotationKeyVariant did not unwrap successfully."));
OP_REQUIRES_OK(context, const_cast<RotationKeyVariant<T>*>(rotation_key_var)
->MaybeLazyDecode(shell_ctx_var->ct_context_));

OP_REQUIRES_OK(context, ValidateUnsortedSegmentReduction(
this, context, shell_ctx_var, data, segment_ids,
tf_shell/cc/kernels/utils.h (5 changes: 4 additions & 1 deletion)
@@ -36,7 +36,10 @@ using tensorflow::errors::InvalidArgument;
using tensorflow::errors::Unimplemented;

// The substitution power for Galois rotation by one slot.
constexpr int base_power = 5;
constexpr int kSubstitutionBasePower = 5;

// The base of the Galois rotation gadget.
constexpr int kLogGadgetBase = 4;

// A mutex for use with variants with appropriate copy/assign.
struct variant_mutex {
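kLogGadgetBase = 4 is the per-modulus value placed in log_bs and passed to RnsGadget::Create in RotationKeyGenOp and MaybeLazyDecode. Assuming the gadget follows the usual gadget-decomposition construction, key switching then works with digits of size 2^4 = 16. The snippet below only illustrates that digit decomposition on a made-up coefficient; nothing in it is taken from the commit.

#include <cstdio>
#include <vector>

int main() {
  const unsigned log_base = 4;           // mirrors kLogGadgetBase
  const unsigned base = 1u << log_base;  // 16
  unsigned coeff = 0x9F3;                // arbitrary example coefficient
  std::vector<unsigned> digits;
  while (coeff != 0) {
    digits.push_back(coeff % base);      // least significant digit first
    coeff /= base;
  }
  std::printf("digits (LSB first):");
  for (unsigned d : digits) std::printf(" %u", d);  // prints 3 15 9
  std::printf("\n");
  return 0;
}

A smaller base generally means more key-switching components but less noise added per component, which is the usual reason the base is left as a tunable constant.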
