Skip to content

Commit

Permalink
Merge pull request #1175 from zama-ai/fix/keyset_deser
Browse files Browse the repository at this point in the history
Reduce memory copies of serialized buffers
  • Loading branch information
youben11 authored Dec 10, 2024
2 parents 2a1f449 + 98c60ba commit 0dd8bec
Show file tree
Hide file tree
Showing 17 changed files with 311 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class LweSecretKey {
static LweSecretKey
fromProto(const Message<concreteprotocol::LweSecretKey> &proto);

static LweSecretKey fromProto(concreteprotocol::LweSecretKey::Reader reader);

Message<concreteprotocol::LweSecretKey> toProto() const;

const uint64_t *getRawPtr() const;
Expand Down Expand Up @@ -95,6 +97,10 @@ class LweBootstrapKey {
static LweBootstrapKey
fromProto(const Message<concreteprotocol::LweBootstrapKey> &proto);

/// @brief Initialize the key from a reader.
static LweBootstrapKey
fromProto(concreteprotocol::LweBootstrapKey::Reader reader);

/// @brief Returns the serialized form of the key.
Message<concreteprotocol::LweBootstrapKey> toProto() const;

Expand Down Expand Up @@ -147,6 +153,10 @@ class LweKeyswitchKey {
static LweKeyswitchKey
fromProto(const Message<concreteprotocol::LweKeyswitchKey> &proto);

/// @brief Initialize the key from a reader.
static LweKeyswitchKey
fromProto(concreteprotocol::LweKeyswitchKey::Reader reader);

/// @brief Returns the serialized form of the key.
Message<concreteprotocol::LweKeyswitchKey> toProto() const;

Expand Down Expand Up @@ -199,6 +209,9 @@ class PackingKeyswitchKey {
static PackingKeyswitchKey
fromProto(const Message<concreteprotocol::PackingKeyswitchKey> &proto);

static PackingKeyswitchKey
fromProto(concreteprotocol::PackingKeyswitchKey::Reader reader);

Message<concreteprotocol::PackingKeyswitchKey> toProto() const;

const uint64_t *getRawPtr() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct ClientKeyset {
static ClientKeyset
fromProto(const Message<concreteprotocol::ClientKeyset> &proto);

static ClientKeyset fromProto(concreteprotocol::ClientKeyset::Reader reader);

Message<concreteprotocol::ClientKeyset> toProto() const;
};

Expand All @@ -43,6 +45,7 @@ struct ServerKeyset {

static ServerKeyset
fromProto(const Message<concreteprotocol::ServerKeyset> &proto);
static ServerKeyset fromProto(concreteprotocol::ServerKeyset::Reader reader);

Message<concreteprotocol::ServerKeyset> toProto() const;
};
Expand Down Expand Up @@ -73,6 +76,7 @@ struct Keyset {
: server(server), client(client) {}

static Keyset fromProto(const Message<concreteprotocol::Keyset> &proto);
static Keyset fromProto(concreteprotocol::Keyset::Reader reader);

Message<concreteprotocol::Keyset> toProto() const;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ template <typename MessageType> struct Message {
message = regionBuilder->initRoot<MessageType>();
}

Message(const typename MessageType::Reader &reader) : message(nullptr) {
explicit Message(const typename MessageType::Reader &reader)
: message(nullptr) {
regionBuilder = new capnp::MallocMessageBuilder(
std::min(reader.totalSize().wordCount, MAX_SEGMENT_SIZE),
capnp::AllocationStrategy::FIXED_SIZE);
Expand Down Expand Up @@ -308,7 +309,12 @@ vectorToProtoPayload(const std::vector<T> &input) {
template <typename T>
std::vector<T>
protoPayloadToVector(const Message<concreteprotocol::Payload> &input) {
auto payloadData = input.asReader().getData();
return protoPayloadToVector<T>(input.asReader());
}

template <typename T>
std::vector<T> protoPayloadToVector(concreteprotocol::Payload::Reader reader) {
auto payloadData = reader.getData();
auto elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
size_t totalPayloadSize = 0;
for (auto blob : payloadData) {
Expand All @@ -331,7 +337,13 @@ protoPayloadToVector(const Message<concreteprotocol::Payload> &input) {
template <typename T>
std::shared_ptr<std::vector<T>>
protoPayloadToSharedVector(const Message<concreteprotocol::Payload> &input) {
auto payloadData = input.asReader().getData();
return protoPayloadToSharedVector<T>(input.asReader());
}

template <typename T>
std::shared_ptr<std::vector<T>>
protoPayloadToSharedVector(concreteprotocol::Payload::Reader reader) {
auto payloadData = reader.getData();
size_t elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
size_t totalPayloadSize = 0;
for (auto blob : payloadData) {
Expand All @@ -353,6 +365,8 @@ protoPayloadToSharedVector(const Message<concreteprotocol::Payload> &input) {
/// dimensions.
std::vector<size_t>
protoShapeToDimensions(const Message<concreteprotocol::Shape> &shape);
std::vector<size_t>
protoShapeToDimensions(concreteprotocol::Shape::Reader reader);

/// Helper function turning a protocol `Shape` object into a vector of
/// dimensions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ struct Value {

bool
isCompatibleWithShape(const Message<concreteprotocol::Shape> &shape) const;
bool isCompatibleWithShape(concreteprotocol::Shape::Reader reader) const;

bool isScalar() const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ class TestProgram {
}
OUTCOME_TRY(auto lib, getLibrary());
OUTCOME_TRY(auto programInfo, lib.getProgramInfo());
auto keysetInfo =
(Message<concreteprotocol::KeysetInfo>)programInfo.asReader()
.getKeyset();
if (tryCache) {
OUTCOME_TRY(keyset, getTestKeySetCachePtr()->getKeyset(
programInfo.asReader().getKeyset(), secretSeed,
encryptionSeed));
keysetInfo, secretSeed, encryptionSeed));
} else {
auto encryptionCsprng = csprng::EncryptionCSPRNG(encryptionSeed);
auto secretCsprng = csprng::SecretCSPRNG(secretSeed);
Message<concreteprotocol::KeysetInfo> keysetInfo =
programInfo.asReader().getKeyset();
keyset = Keyset(keysetInfo, secretCsprng, encryptionCsprng);
}
return outcome::success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto secretKeys = std::vector<LweSecretKeyParam>();
for (auto key : keysetInfo.asReader().getLweSecretKeys()) {
secretKeys.push_back(LweSecretKeyParam{key});
secretKeys.push_back(LweSecretKeyParam{
(Message<concreteprotocol::LweSecretKeyInfo>)key});
}
return secretKeys;
},
Expand All @@ -1117,7 +1118,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto bootstrapKeys = std::vector<BootstrapKeyParam>();
for (auto key : keysetInfo.asReader().getLweBootstrapKeys()) {
bootstrapKeys.push_back(BootstrapKeyParam{key});
bootstrapKeys.push_back(BootstrapKeyParam{
(Message<concreteprotocol::LweBootstrapKeyInfo>)key});
}
return bootstrapKeys;
},
Expand All @@ -1127,7 +1129,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto keyswitchKeys = std::vector<KeyswitchKeyParam>();
for (auto key : keysetInfo.asReader().getLweKeyswitchKeys()) {
keyswitchKeys.push_back(KeyswitchKeyParam{key});
keyswitchKeys.push_back(KeyswitchKeyParam{
(Message<concreteprotocol::LweKeyswitchKeyInfo>)key});
}
return keyswitchKeys;
},
Expand All @@ -1137,7 +1140,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto packingKeyswitchKeys = std::vector<PackingKeyswitchKeyParam>();
for (auto key : keysetInfo.asReader().getPackingKeyswitchKeys()) {
packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{key});
packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{
(Message<concreteprotocol::PackingKeyswitchKeyInfo>)key});
}
return packingKeyswitchKeys;
},
Expand Down Expand Up @@ -1220,13 +1224,13 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(
"get_type_info",
[](GateInfo &gate) -> TypeInfo {
return {gate.asReader().getTypeInfo()};
return {(TypeInfo)gate.asReader().getTypeInfo()};
},
"Return the type associated to the gate.")
.def(
"get_raw_info",
[](GateInfo &gate) -> RawInfo {
return {gate.asReader().getRawInfo()};
return {(RawInfo)gate.asReader().getRawInfo()};
},
"Return the raw type associated to the gate.")
.doc() = "Informations describing a circuit gate (input or output).";
Expand All @@ -1247,7 +1251,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CircuitInfo &circuit) -> std::vector<GateInfo> {
auto output = std::vector<GateInfo>();
for (auto gate : circuit.asReader().getInputs()) {
output.push_back({gate});
output.push_back({(Message<concreteprotocol::GateInfo>)gate});
}
return output;
},
Expand All @@ -1257,7 +1261,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CircuitInfo &circuit) -> std::vector<GateInfo> {
auto output = std::vector<GateInfo>();
for (auto gate : circuit.asReader().getOutputs()) {
output.push_back({gate});
output.push_back({(Message<concreteprotocol::GateInfo>)gate});
}
return output;
},
Expand Down Expand Up @@ -1415,7 +1419,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(
"get_keyset_info",
[](ProgramInfo &programInfo) -> KeysetInfo {
return programInfo.programInfo.asReader().getKeyset();
return (KeysetInfo)programInfo.programInfo.asReader().getKeyset();
},
"Return the keyset info associated to the program.")
.def(
Expand All @@ -1424,7 +1428,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
auto output = std::vector<CircuitInfo>();
for (auto circuit :
programInfo.programInfo.asReader().getCircuits()) {
output.push_back(circuit);
output.push_back((CircuitInfo)circuit);
}
return output;
},
Expand All @@ -1435,7 +1439,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
for (auto circuit :
programInfo.programInfo.asReader().getCircuits()) {
if (circuit.getName() == name) {
return circuit;
return (CircuitInfo)circuit;
}
}
throw std::runtime_error("couldn't find circuit.");
Expand Down Expand Up @@ -1552,7 +1556,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Failed to deserialize server keyset." +
maybeError.as_failure().error().mesg);
}
return ServerKeyset::fromProto(serverKeysetProto);
return ServerKeyset::fromProto(serverKeysetProto.asReader());
},
"Deserialize a ServerKeyset from bytes.", arg("bytes"))
.def(
Expand Down Expand Up @@ -1604,17 +1608,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
GET_OR_THROW_RESULT(
Keyset keyset,
(*cache).getKeyset(
programInfo.programInfo.asReader().getKeyset(),
(KeysetInfo)programInfo.programInfo.asReader()
.getKeyset(),
secretSeed, encryptionSeed,
initialLweSecretKeys.value()));
return std::make_unique<Keyset>(std::move(keyset));
} else {
::concretelang::csprng::SecretCSPRNG secCsprng(secretSeed);
::concretelang::csprng::EncryptionCSPRNG encCsprng(
encryptionSeed);
auto keyset =
Keyset(programInfo.programInfo.asReader().getKeyset(),
secCsprng, encCsprng, initialLweSecretKeys.value());
auto keyset = Keyset(
(KeysetInfo)programInfo.programInfo.asReader().getKeyset(),
secCsprng, encCsprng, initialLweSecretKeys.value());
return std::make_unique<Keyset>(std::move(keyset));
}
}),
Expand Down Expand Up @@ -1652,7 +1657,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Failed to deserialize keyset." +
maybeError.as_failure().error().mesg);
}
auto keyset = Keyset::fromProto(keysetProto);
auto keyset = Keyset::fromProto(std::move(keysetProto));
return std::make_unique<Keyset>(std::move(keyset));
},
"Deserialize a Keyset from a file.", arg("path"))
Expand All @@ -1670,6 +1675,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return pybind11::bytes(keySetSerialize(keySet));
},
"Serialize a Keyset to bytes.")
.def(
"serialize_to_file",
[](Keyset &keySet, const std::string path) {
std::ofstream ofs;
ofs.open(path);
if (!ofs.good()) {
throw std::runtime_error("Failed to open keyset file " + path);
}
auto keysetProto = keySet.toProto();
auto maybeBuffer = keysetProto.writeBinaryToOstream(ofs);
if (maybeBuffer.has_failure()) {
throw std::runtime_error("Failed to serialize keys.");
}
},
"Serialize a Keyset to bytes.")
.def(
"serialize_lwe_secret_key_as_glwe",
[](Keyset &keyset, size_t keyIndex, size_t glwe_dimension,
Expand Down Expand Up @@ -2034,8 +2054,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(

GET_OR_THROW_RESULT(auto pi, library.getProgramInfo());
GET_OR_THROW_RESULT(
auto result, ServerProgram::load(pi.asReader(), sharedLibPath,
useSimulation));
auto result,
ServerProgram::load(
(Message<concreteprotocol::ProgramInfo>)pi.asReader(),
sharedLibPath, useSimulation));
return result;
}),
arg("library"), arg("use_simulation"))
Expand All @@ -2061,7 +2083,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Unknown position.");
}
auto info = circuit.getCircuitInfo().asReader().getInputs()[pos];
auto typeTransformer = getPythonTypeTransformer(info);
auto typeTransformer = getPythonTypeTransformer((GateInfo)info);
GET_OR_THROW_RESULT(
auto ok, circuit.prepareInput(typeTransformer(arg), pos));
return ok;
Expand All @@ -2084,7 +2106,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Unknown position.");
}
auto info = circuit.getCircuitInfo().asReader().getInputs()[pos];
auto typeTransformer = getPythonTypeTransformer(info);
auto typeTransformer = getPythonTypeTransformer((GateInfo)info);
GET_OR_THROW_RESULT(auto ok, circuit.simulatePrepareInput(
typeTransformer(arg), pos));
return ok;
Expand Down
Loading

0 comments on commit 0dd8bec

Please sign in to comment.