Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory copies of serialized buffers #1175

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class LweSecretKey {
static LweSecretKey
fromProto(const Message<concreteprotocol::LweSecretKey> &proto);

static LweSecretKey fromProto(concreteprotocol::LweSecretKey::Reader reader);

Message<concreteprotocol::LweSecretKey> toProto() const;

const uint64_t *getRawPtr() const;
Expand Down Expand Up @@ -95,6 +97,10 @@ class LweBootstrapKey {
static LweBootstrapKey
fromProto(const Message<concreteprotocol::LweBootstrapKey> &proto);

/// @brief Initialize the key from a reader.
static LweBootstrapKey
fromProto(concreteprotocol::LweBootstrapKey::Reader reader);

/// @brief Returns the serialized form of the key.
Message<concreteprotocol::LweBootstrapKey> toProto() const;

Expand Down Expand Up @@ -147,6 +153,10 @@ class LweKeyswitchKey {
static LweKeyswitchKey
fromProto(const Message<concreteprotocol::LweKeyswitchKey> &proto);

/// @brief Initialize the key from a reader.
static LweKeyswitchKey
fromProto(concreteprotocol::LweKeyswitchKey::Reader reader);

/// @brief Returns the serialized form of the key.
Message<concreteprotocol::LweKeyswitchKey> toProto() const;

Expand Down Expand Up @@ -199,6 +209,9 @@ class PackingKeyswitchKey {
static PackingKeyswitchKey
fromProto(const Message<concreteprotocol::PackingKeyswitchKey> &proto);

static PackingKeyswitchKey
fromProto(concreteprotocol::PackingKeyswitchKey::Reader reader);

Message<concreteprotocol::PackingKeyswitchKey> toProto() const;

const uint64_t *getRawPtr() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct ClientKeyset {
static ClientKeyset
fromProto(const Message<concreteprotocol::ClientKeyset> &proto);

static ClientKeyset fromProto(concreteprotocol::ClientKeyset::Reader reader);

Message<concreteprotocol::ClientKeyset> toProto() const;
};

Expand All @@ -43,6 +45,7 @@ struct ServerKeyset {

static ServerKeyset
fromProto(const Message<concreteprotocol::ServerKeyset> &proto);
static ServerKeyset fromProto(concreteprotocol::ServerKeyset::Reader reader);

Message<concreteprotocol::ServerKeyset> toProto() const;
};
Expand Down Expand Up @@ -73,6 +76,7 @@ struct Keyset {
: server(server), client(client) {}

static Keyset fromProto(const Message<concreteprotocol::Keyset> &proto);
static Keyset fromProto(concreteprotocol::Keyset::Reader reader);

Message<concreteprotocol::Keyset> toProto() const;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ template <typename MessageType> struct Message {
message = regionBuilder->initRoot<MessageType>();
}

Message(const typename MessageType::Reader &reader) : message(nullptr) {
explicit Message(const typename MessageType::Reader &reader)
: message(nullptr) {
regionBuilder = new capnp::MallocMessageBuilder(
std::min(reader.totalSize().wordCount, MAX_SEGMENT_SIZE),
capnp::AllocationStrategy::FIXED_SIZE);
Expand Down Expand Up @@ -308,7 +309,12 @@ vectorToProtoPayload(const std::vector<T> &input) {
template <typename T>
std::vector<T>
protoPayloadToVector(const Message<concreteprotocol::Payload> &input) {
auto payloadData = input.asReader().getData();
return protoPayloadToVector<T>(input.asReader());
}

template <typename T>
std::vector<T> protoPayloadToVector(concreteprotocol::Payload::Reader reader) {
auto payloadData = reader.getData();
auto elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
size_t totalPayloadSize = 0;
for (auto blob : payloadData) {
Expand All @@ -331,7 +337,13 @@ protoPayloadToVector(const Message<concreteprotocol::Payload> &input) {
template <typename T>
std::shared_ptr<std::vector<T>>
protoPayloadToSharedVector(const Message<concreteprotocol::Payload> &input) {
auto payloadData = input.asReader().getData();
return protoPayloadToSharedVector<T>(input.asReader());
}

template <typename T>
std::shared_ptr<std::vector<T>>
protoPayloadToSharedVector(concreteprotocol::Payload::Reader reader) {
auto payloadData = reader.getData();
size_t elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
size_t totalPayloadSize = 0;
for (auto blob : payloadData) {
Expand All @@ -353,6 +365,8 @@ protoPayloadToSharedVector(const Message<concreteprotocol::Payload> &input) {
/// dimensions.
std::vector<size_t>
protoShapeToDimensions(const Message<concreteprotocol::Shape> &shape);
std::vector<size_t>
protoShapeToDimensions(concreteprotocol::Shape::Reader reader);

/// Helper function turning a protocol `Shape` object into a vector of
/// dimensions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ struct Value {

bool
isCompatibleWithShape(const Message<concreteprotocol::Shape> &shape) const;
bool isCompatibleWithShape(concreteprotocol::Shape::Reader reader) const;

bool isScalar() const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ class TestProgram {
}
OUTCOME_TRY(auto lib, getLibrary());
OUTCOME_TRY(auto programInfo, lib.getProgramInfo());
auto keysetInfo =
(Message<concreteprotocol::KeysetInfo>)programInfo.asReader()
.getKeyset();
if (tryCache) {
OUTCOME_TRY(keyset, getTestKeySetCachePtr()->getKeyset(
programInfo.asReader().getKeyset(), secretSeed,
encryptionSeed));
keysetInfo, secretSeed, encryptionSeed));
} else {
auto encryptionCsprng = csprng::EncryptionCSPRNG(encryptionSeed);
auto secretCsprng = csprng::SecretCSPRNG(secretSeed);
Message<concreteprotocol::KeysetInfo> keysetInfo =
programInfo.asReader().getKeyset();
keyset = Keyset(keysetInfo, secretCsprng, encryptionCsprng);
}
return outcome::success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto secretKeys = std::vector<LweSecretKeyParam>();
for (auto key : keysetInfo.asReader().getLweSecretKeys()) {
secretKeys.push_back(LweSecretKeyParam{key});
secretKeys.push_back(LweSecretKeyParam{
(Message<concreteprotocol::LweSecretKeyInfo>)key});
}
return secretKeys;
},
Expand All @@ -1117,7 +1118,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto bootstrapKeys = std::vector<BootstrapKeyParam>();
for (auto key : keysetInfo.asReader().getLweBootstrapKeys()) {
bootstrapKeys.push_back(BootstrapKeyParam{key});
bootstrapKeys.push_back(BootstrapKeyParam{
(Message<concreteprotocol::LweBootstrapKeyInfo>)key});
}
return bootstrapKeys;
},
Expand All @@ -1127,7 +1129,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto keyswitchKeys = std::vector<KeyswitchKeyParam>();
for (auto key : keysetInfo.asReader().getLweKeyswitchKeys()) {
keyswitchKeys.push_back(KeyswitchKeyParam{key});
keyswitchKeys.push_back(KeyswitchKeyParam{
(Message<concreteprotocol::LweKeyswitchKeyInfo>)key});
}
return keyswitchKeys;
},
Expand All @@ -1137,7 +1140,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto packingKeyswitchKeys = std::vector<PackingKeyswitchKeyParam>();
for (auto key : keysetInfo.asReader().getPackingKeyswitchKeys()) {
packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{key});
packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{
(Message<concreteprotocol::PackingKeyswitchKeyInfo>)key});
}
return packingKeyswitchKeys;
},
Expand Down Expand Up @@ -1220,13 +1224,13 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(
"get_type_info",
[](GateInfo &gate) -> TypeInfo {
return {gate.asReader().getTypeInfo()};
return {(TypeInfo)gate.asReader().getTypeInfo()};
},
"Return the type associated to the gate.")
.def(
"get_raw_info",
[](GateInfo &gate) -> RawInfo {
return {gate.asReader().getRawInfo()};
return {(RawInfo)gate.asReader().getRawInfo()};
},
"Return the raw type associated to the gate.")
.doc() = "Informations describing a circuit gate (input or output).";
Expand All @@ -1247,7 +1251,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CircuitInfo &circuit) -> std::vector<GateInfo> {
auto output = std::vector<GateInfo>();
for (auto gate : circuit.asReader().getInputs()) {
output.push_back({gate});
output.push_back({(Message<concreteprotocol::GateInfo>)gate});
}
return output;
},
Expand All @@ -1257,7 +1261,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CircuitInfo &circuit) -> std::vector<GateInfo> {
auto output = std::vector<GateInfo>();
for (auto gate : circuit.asReader().getOutputs()) {
output.push_back({gate});
output.push_back({(Message<concreteprotocol::GateInfo>)gate});
}
return output;
},
Expand Down Expand Up @@ -1415,7 +1419,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(
"get_keyset_info",
[](ProgramInfo &programInfo) -> KeysetInfo {
return programInfo.programInfo.asReader().getKeyset();
return (KeysetInfo)programInfo.programInfo.asReader().getKeyset();
},
"Return the keyset info associated to the program.")
.def(
Expand All @@ -1424,7 +1428,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
auto output = std::vector<CircuitInfo>();
for (auto circuit :
programInfo.programInfo.asReader().getCircuits()) {
output.push_back(circuit);
output.push_back((CircuitInfo)circuit);
}
return output;
},
Expand All @@ -1435,7 +1439,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
for (auto circuit :
programInfo.programInfo.asReader().getCircuits()) {
if (circuit.getName() == name) {
return circuit;
return (CircuitInfo)circuit;
}
}
throw std::runtime_error("couldn't find circuit.");
Expand Down Expand Up @@ -1552,7 +1556,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Failed to deserialize server keyset." +
maybeError.as_failure().error().mesg);
}
return ServerKeyset::fromProto(serverKeysetProto);
return ServerKeyset::fromProto(serverKeysetProto.asReader());
},
"Deserialize a ServerKeyset from bytes.", arg("bytes"))
.def(
Expand Down Expand Up @@ -1604,17 +1608,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
GET_OR_THROW_RESULT(
Keyset keyset,
(*cache).getKeyset(
programInfo.programInfo.asReader().getKeyset(),
(KeysetInfo)programInfo.programInfo.asReader()
.getKeyset(),
secretSeed, encryptionSeed,
initialLweSecretKeys.value()));
return std::make_unique<Keyset>(std::move(keyset));
} else {
::concretelang::csprng::SecretCSPRNG secCsprng(secretSeed);
::concretelang::csprng::EncryptionCSPRNG encCsprng(
encryptionSeed);
auto keyset =
Keyset(programInfo.programInfo.asReader().getKeyset(),
secCsprng, encCsprng, initialLweSecretKeys.value());
auto keyset = Keyset(
(KeysetInfo)programInfo.programInfo.asReader().getKeyset(),
secCsprng, encCsprng, initialLweSecretKeys.value());
return std::make_unique<Keyset>(std::move(keyset));
}
}),
Expand Down Expand Up @@ -1652,7 +1657,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Failed to deserialize keyset." +
maybeError.as_failure().error().mesg);
}
auto keyset = Keyset::fromProto(keysetProto);
auto keyset = Keyset::fromProto(std::move(keysetProto));
return std::make_unique<Keyset>(std::move(keyset));
},
"Deserialize a Keyset from a file.", arg("path"))
Expand All @@ -1670,6 +1675,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return pybind11::bytes(keySetSerialize(keySet));
},
"Serialize a Keyset to bytes.")
.def(
youben11 marked this conversation as resolved.
Show resolved Hide resolved
"serialize_to_file",
[](Keyset &keySet, const std::string path) {
std::ofstream ofs;
ofs.open(path);
if (!ofs.good()) {
throw std::runtime_error("Failed to open keyset file " + path);
}
auto keysetProto = keySet.toProto();
auto maybeBuffer = keysetProto.writeBinaryToOstream(ofs);
if (maybeBuffer.has_failure()) {
throw std::runtime_error("Failed to serialize keys.");
}
},
"Serialize a Keyset to bytes.")
.def(
"serialize_lwe_secret_key_as_glwe",
[](Keyset &keyset, size_t keyIndex, size_t glwe_dimension,
Expand Down Expand Up @@ -2034,8 +2054,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(

GET_OR_THROW_RESULT(auto pi, library.getProgramInfo());
GET_OR_THROW_RESULT(
auto result, ServerProgram::load(pi.asReader(), sharedLibPath,
useSimulation));
auto result,
ServerProgram::load(
(Message<concreteprotocol::ProgramInfo>)pi.asReader(),
sharedLibPath, useSimulation));
return result;
}),
arg("library"), arg("use_simulation"))
Expand All @@ -2061,7 +2083,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Unknown position.");
}
auto info = circuit.getCircuitInfo().asReader().getInputs()[pos];
auto typeTransformer = getPythonTypeTransformer(info);
auto typeTransformer = getPythonTypeTransformer((GateInfo)info);
GET_OR_THROW_RESULT(
auto ok, circuit.prepareInput(typeTransformer(arg), pos));
return ok;
Expand All @@ -2084,7 +2106,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Unknown position.");
}
auto info = circuit.getCircuitInfo().asReader().getInputs()[pos];
auto typeTransformer = getPythonTypeTransformer(info);
auto typeTransformer = getPythonTypeTransformer((GateInfo)info);
GET_OR_THROW_RESULT(auto ok, circuit.simulatePrepareInput(
typeTransformer(arg), pos));
return ok;
Expand Down
Loading
Loading