diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h index 1f35c6f373..f35e037cff 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h @@ -53,6 +53,8 @@ class LweSecretKey { static LweSecretKey fromProto(const Message &proto); + static LweSecretKey fromProto(concreteprotocol::LweSecretKey::Reader reader); + Message toProto() const; const uint64_t *getRawPtr() const; @@ -95,6 +97,10 @@ class LweBootstrapKey { static LweBootstrapKey fromProto(const Message &proto); + /// @brief Initialize the key from a reader. + static LweBootstrapKey + fromProto(concreteprotocol::LweBootstrapKey::Reader reader); + /// @brief Returns the serialized form of the key. Message toProto() const; @@ -147,6 +153,10 @@ class LweKeyswitchKey { static LweKeyswitchKey fromProto(const Message &proto); + /// @brief Initialize the key from a reader. + static LweKeyswitchKey + fromProto(concreteprotocol::LweKeyswitchKey::Reader reader); + /// @brief Returns the serialized form of the key. Message toProto() const; @@ -199,6 +209,9 @@ class PackingKeyswitchKey { static PackingKeyswitchKey fromProto(const Message &proto); + static PackingKeyswitchKey + fromProto(concreteprotocol::PackingKeyswitchKey::Reader reader); + Message toProto() const; const uint64_t *getRawPtr() const; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h index 0e84df64b4..271254d1ae 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h @@ -33,6 +33,8 @@ struct ClientKeyset { static ClientKeyset fromProto(const Message &proto); + static ClientKeyset fromProto(concreteprotocol::ClientKeyset::Reader reader); + Message toProto() const; }; @@ -43,6 +45,7 @@ struct ServerKeyset { static ServerKeyset fromProto(const Message &proto); + static ServerKeyset fromProto(concreteprotocol::ServerKeyset::Reader reader); Message toProto() const; }; @@ -73,6 +76,7 @@ struct Keyset { : server(server), client(client) {} static Keyset fromProto(const Message &proto); + static Keyset fromProto(concreteprotocol::Keyset::Reader reader); Message toProto() const; }; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h index bf11a32a33..c897be04f4 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h @@ -69,7 +69,8 @@ template struct Message { message = regionBuilder->initRoot(); } - Message(const typename MessageType::Reader &reader) : message(nullptr) { + explicit Message(const typename MessageType::Reader &reader) + : message(nullptr) { regionBuilder = new capnp::MallocMessageBuilder( std::min(reader.totalSize().wordCount, MAX_SEGMENT_SIZE), capnp::AllocationStrategy::FIXED_SIZE); @@ -308,7 +309,12 @@ vectorToProtoPayload(const std::vector &input) { template std::vector protoPayloadToVector(const Message &input) { - auto payloadData = input.asReader().getData(); + return protoPayloadToVector(input.asReader()); +} + +template +std::vector protoPayloadToVector(concreteprotocol::Payload::Reader reader) { + auto payloadData = reader.getData(); auto elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T); size_t totalPayloadSize = 0; for (auto blob : payloadData) { @@ -331,7 +337,13 @@ protoPayloadToVector(const Message &input) { template std::shared_ptr> protoPayloadToSharedVector(const Message &input) { - auto payloadData = input.asReader().getData(); + return protoPayloadToSharedVector(input.asReader()); +} + +template +std::shared_ptr> +protoPayloadToSharedVector(concreteprotocol::Payload::Reader reader) { + auto payloadData = reader.getData(); size_t elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T); size_t totalPayloadSize = 0; for (auto blob : payloadData) { @@ -353,6 +365,8 @@ protoPayloadToSharedVector(const Message &input) { /// dimensions. std::vector protoShapeToDimensions(const Message &shape); +std::vector +protoShapeToDimensions(concreteprotocol::Shape::Reader reader); /// Helper function turning a protocol `Shape` object into a vector of /// dimensions. diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Values.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Values.h index 33dfc477e2..1d59ba84fb 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Values.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Values.h @@ -198,6 +198,7 @@ struct Value { bool isCompatibleWithShape(const Message &shape) const; + bool isCompatibleWithShape(concreteprotocol::Shape::Reader reader) const; bool isScalar() const; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h b/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h index 26945db0e4..95b73578ec 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h @@ -81,15 +81,15 @@ class TestProgram { } OUTCOME_TRY(auto lib, getLibrary()); OUTCOME_TRY(auto programInfo, lib.getProgramInfo()); + auto keysetInfo = + (Message)programInfo.asReader() + .getKeyset(); if (tryCache) { OUTCOME_TRY(keyset, getTestKeySetCachePtr()->getKeyset( - programInfo.asReader().getKeyset(), secretSeed, - encryptionSeed)); + keysetInfo, secretSeed, encryptionSeed)); } else { auto encryptionCsprng = csprng::EncryptionCSPRNG(encryptionSeed); auto secretCsprng = csprng::SecretCSPRNG(secretSeed); - Message keysetInfo = - programInfo.asReader().getKeyset(); keyset = Keyset(keysetInfo, secretCsprng, encryptionCsprng); } return outcome::success(); diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 28d5f9ce2f..d699427966 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1107,7 +1107,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](KeysetInfo &keysetInfo) { auto secretKeys = std::vector(); for (auto key : keysetInfo.asReader().getLweSecretKeys()) { - secretKeys.push_back(LweSecretKeyParam{key}); + secretKeys.push_back(LweSecretKeyParam{ + (Message)key}); } return secretKeys; }, @@ -1117,7 +1118,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](KeysetInfo &keysetInfo) { auto bootstrapKeys = std::vector(); for (auto key : keysetInfo.asReader().getLweBootstrapKeys()) { - bootstrapKeys.push_back(BootstrapKeyParam{key}); + bootstrapKeys.push_back(BootstrapKeyParam{ + (Message)key}); } return bootstrapKeys; }, @@ -1127,7 +1129,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](KeysetInfo &keysetInfo) { auto keyswitchKeys = std::vector(); for (auto key : keysetInfo.asReader().getLweKeyswitchKeys()) { - keyswitchKeys.push_back(KeyswitchKeyParam{key}); + keyswitchKeys.push_back(KeyswitchKeyParam{ + (Message)key}); } return keyswitchKeys; }, @@ -1137,7 +1140,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](KeysetInfo &keysetInfo) { auto packingKeyswitchKeys = std::vector(); for (auto key : keysetInfo.asReader().getPackingKeyswitchKeys()) { - packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{key}); + packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{ + (Message)key}); } return packingKeyswitchKeys; }, @@ -1220,13 +1224,13 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def( "get_type_info", [](GateInfo &gate) -> TypeInfo { - return {gate.asReader().getTypeInfo()}; + return {(TypeInfo)gate.asReader().getTypeInfo()}; }, "Return the type associated to the gate.") .def( "get_raw_info", [](GateInfo &gate) -> RawInfo { - return {gate.asReader().getRawInfo()}; + return {(RawInfo)gate.asReader().getRawInfo()}; }, "Return the raw type associated to the gate.") .doc() = "Informations describing a circuit gate (input or output)."; @@ -1247,7 +1251,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](CircuitInfo &circuit) -> std::vector { auto output = std::vector(); for (auto gate : circuit.asReader().getInputs()) { - output.push_back({gate}); + output.push_back({(Message)gate}); } return output; }, @@ -1257,7 +1261,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](CircuitInfo &circuit) -> std::vector { auto output = std::vector(); for (auto gate : circuit.asReader().getOutputs()) { - output.push_back({gate}); + output.push_back({(Message)gate}); } return output; }, @@ -1415,7 +1419,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def( "get_keyset_info", [](ProgramInfo &programInfo) -> KeysetInfo { - return programInfo.programInfo.asReader().getKeyset(); + return (KeysetInfo)programInfo.programInfo.asReader().getKeyset(); }, "Return the keyset info associated to the program.") .def( @@ -1424,7 +1428,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( auto output = std::vector(); for (auto circuit : programInfo.programInfo.asReader().getCircuits()) { - output.push_back(circuit); + output.push_back((CircuitInfo)circuit); } return output; }, @@ -1435,7 +1439,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( for (auto circuit : programInfo.programInfo.asReader().getCircuits()) { if (circuit.getName() == name) { - return circuit; + return (CircuitInfo)circuit; } } throw std::runtime_error("couldn't find circuit."); @@ -1552,7 +1556,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( throw std::runtime_error("Failed to deserialize server keyset." + maybeError.as_failure().error().mesg); } - return ServerKeyset::fromProto(serverKeysetProto); + return ServerKeyset::fromProto(serverKeysetProto.asReader()); }, "Deserialize a ServerKeyset from bytes.", arg("bytes")) .def( @@ -1604,7 +1608,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( GET_OR_THROW_RESULT( Keyset keyset, (*cache).getKeyset( - programInfo.programInfo.asReader().getKeyset(), + (KeysetInfo)programInfo.programInfo.asReader() + .getKeyset(), secretSeed, encryptionSeed, initialLweSecretKeys.value())); return std::make_unique(std::move(keyset)); @@ -1612,9 +1617,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( ::concretelang::csprng::SecretCSPRNG secCsprng(secretSeed); ::concretelang::csprng::EncryptionCSPRNG encCsprng( encryptionSeed); - auto keyset = - Keyset(programInfo.programInfo.asReader().getKeyset(), - secCsprng, encCsprng, initialLweSecretKeys.value()); + auto keyset = Keyset( + (KeysetInfo)programInfo.programInfo.asReader().getKeyset(), + secCsprng, encCsprng, initialLweSecretKeys.value()); return std::make_unique(std::move(keyset)); } }), @@ -1652,7 +1657,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( throw std::runtime_error("Failed to deserialize keyset." + maybeError.as_failure().error().mesg); } - auto keyset = Keyset::fromProto(keysetProto); + auto keyset = Keyset::fromProto(std::move(keysetProto)); return std::make_unique(std::move(keyset)); }, "Deserialize a Keyset from a file.", arg("path")) @@ -1670,6 +1675,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return pybind11::bytes(keySetSerialize(keySet)); }, "Serialize a Keyset to bytes.") + .def( + "serialize_to_file", + [](Keyset &keySet, const std::string path) { + std::ofstream ofs; + ofs.open(path); + if (!ofs.good()) { + throw std::runtime_error("Failed to open keyset file " + path); + } + auto keysetProto = keySet.toProto(); + auto maybeBuffer = keysetProto.writeBinaryToOstream(ofs); + if (maybeBuffer.has_failure()) { + throw std::runtime_error("Failed to serialize keys."); + } + }, + "Serialize a Keyset to bytes.") .def( "serialize_lwe_secret_key_as_glwe", [](Keyset &keyset, size_t keyIndex, size_t glwe_dimension, @@ -2034,8 +2054,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( GET_OR_THROW_RESULT(auto pi, library.getProgramInfo()); GET_OR_THROW_RESULT( - auto result, ServerProgram::load(pi.asReader(), sharedLibPath, - useSimulation)); + auto result, + ServerProgram::load( + (Message)pi.asReader(), + sharedLibPath, useSimulation)); return result; }), arg("library"), arg("use_simulation")) @@ -2061,7 +2083,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( throw std::runtime_error("Unknown position."); } auto info = circuit.getCircuitInfo().asReader().getInputs()[pos]; - auto typeTransformer = getPythonTypeTransformer(info); + auto typeTransformer = getPythonTypeTransformer((GateInfo)info); GET_OR_THROW_RESULT( auto ok, circuit.prepareInput(typeTransformer(arg), pos)); return ok; @@ -2084,7 +2106,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( throw std::runtime_error("Unknown position."); } auto info = circuit.getCircuitInfo().asReader().getInputs()[pos]; - auto typeTransformer = getPythonTypeTransformer(info); + auto typeTransformer = getPythonTypeTransformer((GateInfo)info); GET_OR_THROW_RESULT(auto ok, circuit.simulatePrepareInput( typeTransformer(arg), pos)); return ok; diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp index f6954da889..cf5413e0d3 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp @@ -49,14 +49,17 @@ ClientCircuit::create(const Message &info, InputTransformer transformer; if (gateInfo.getTypeInfo().hasIndex()) { OUTCOME_TRY(transformer, - TransformerFactory::getIndexInputTransformer(gateInfo)); + TransformerFactory::getIndexInputTransformer( + (Message)gateInfo)); } else if (gateInfo.getTypeInfo().hasPlaintext()) { OUTCOME_TRY(transformer, - TransformerFactory::getPlaintextInputTransformer(gateInfo)); + TransformerFactory::getPlaintextInputTransformer( + (Message)gateInfo)); } else if (gateInfo.getTypeInfo().hasLweCiphertext()) { OUTCOME_TRY(transformer, TransformerFactory::getLweCiphertextInputTransformer( - keyset, gateInfo, csprng, useSimulation)); + keyset, (Message)gateInfo, + csprng, useSimulation)); } else { return StringError("Malformed input gate info."); } @@ -69,14 +72,17 @@ ClientCircuit::create(const Message &info, OutputTransformer transformer; if (gateInfo.getTypeInfo().hasIndex()) { OUTCOME_TRY(transformer, - TransformerFactory::getIndexOutputTransformer(gateInfo)); + TransformerFactory::getIndexOutputTransformer( + (Message)gateInfo)); } else if (gateInfo.getTypeInfo().hasPlaintext()) { OUTCOME_TRY(transformer, - TransformerFactory::getPlaintextOutputTransformer(gateInfo)); + TransformerFactory::getPlaintextOutputTransformer( + (Message)gateInfo)); } else if (gateInfo.getTypeInfo().hasLweCiphertext()) { OUTCOME_TRY(transformer, TransformerFactory::getLweCiphertextOutputTransformer( - keyset, gateInfo, useSimulation)); + keyset, (Message)gateInfo, + useSimulation)); } else { return StringError("Malformed output gate info."); } @@ -161,7 +167,9 @@ Result ClientProgram::createEncrypted( ClientProgram output; for (auto circuitInfo : info.asReader().getCircuits()) { OUTCOME_TRY(const ClientCircuit clientCircuit, - ClientCircuit::createEncrypted(circuitInfo, keyset, csprng)); + ClientCircuit::createEncrypted( + (Message)circuitInfo, keyset, + csprng)); output.circuits.push_back(clientCircuit); } return output; @@ -172,8 +180,10 @@ Result ClientProgram::createSimulated( std::shared_ptr csprng) { ClientProgram output; for (auto circuitInfo : info.asReader().getCircuits()) { - OUTCOME_TRY(const ClientCircuit clientCircuit, - ClientCircuit::createSimulated(circuitInfo, csprng)); + OUTCOME_TRY( + const ClientCircuit clientCircuit, + ClientCircuit::createSimulated( + (Message)circuitInfo, csprng)); output.circuits.push_back(clientCircuit); } return output; diff --git a/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp b/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp index d9889dbf0e..cdd9f2c20f 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp @@ -64,11 +64,14 @@ LweSecretKey::LweSecretKey(Message info, LweSecretKey LweSecretKey::fromProto(const Message &proto) { + return fromProto(proto.asReader()); +} - auto info = - Message(proto.asReader().getInfo()); - auto vector = - protoPayloadToSharedVector(proto.asReader().getPayload()); +LweSecretKey +LweSecretKey::fromProto(concreteprotocol::LweSecretKey::Reader reader) { + + auto info = Message(reader.getInfo()); + auto vector = protoPayloadToSharedVector(reader.getPayload()); return LweSecretKey(vector, info); } @@ -137,11 +140,14 @@ LweBootstrapKey::LweBootstrapKey( }; LweBootstrapKey LweBootstrapKey::fromProto( - const Message &proto) { - auto info = Message( - proto.asReader().getInfo()); - auto vector = - protoPayloadToSharedVector(proto.asReader().getPayload()); + const Message &key) { + return fromProto(key.asReader()); +} + +LweBootstrapKey +LweBootstrapKey::fromProto(concreteprotocol::LweBootstrapKey::Reader reader) { + auto info = Message(reader.getInfo()); + auto vector = protoPayloadToSharedVector(reader.getPayload()); LweBootstrapKey key(info); switch (info.asReader().getCompression()) { case concreteprotocol::Compression::NONE: @@ -258,10 +264,13 @@ LweKeyswitchKey::LweKeyswitchKey( LweKeyswitchKey LweKeyswitchKey::fromProto( const Message &proto) { - auto info = Message( - proto.asReader().getInfo()); - auto vector = - protoPayloadToSharedVector(proto.asReader().getPayload()); + return fromProto(proto.asReader()); +} + +LweKeyswitchKey +LweKeyswitchKey::fromProto(concreteprotocol::LweKeyswitchKey::Reader reader) { + auto info = Message(reader.getInfo()); + auto vector = protoPayloadToSharedVector(reader.getPayload()); LweKeyswitchKey key(info); switch (info.asReader().getCompression()) { case concreteprotocol::Compression::NONE: @@ -362,10 +371,14 @@ PackingKeyswitchKey::PackingKeyswitchKey( PackingKeyswitchKey PackingKeyswitchKey::fromProto( const Message &proto) { - auto info = Message( - proto.asReader().getInfo()); - auto vector = - protoPayloadToSharedVector(proto.asReader().getPayload()); + return fromProto(proto.asReader()); +} + +PackingKeyswitchKey PackingKeyswitchKey::fromProto( + concreteprotocol::PackingKeyswitchKey::Reader reader) { + auto info = + Message(reader.getInfo()); + auto vector = protoPayloadToSharedVector(reader.getPayload()); return PackingKeyswitchKey(vector, info); } diff --git a/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp b/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp index 3096e64a0d..8a5964ed27 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp @@ -43,8 +43,13 @@ namespace keysets { ClientKeyset ClientKeyset::fromProto(const Message &proto) { + return fromProto(proto.asReader()); +} + +ClientKeyset +ClientKeyset::fromProto(concreteprotocol::ClientKeyset::Reader reader) { auto output = ClientKeyset(); - for (auto skProto : proto.asReader().getLweSecretKeys()) { + for (auto skProto : reader.getLweSecretKeys()) { output.lweSecretKeys.push_back(LweSecretKey::fromProto(skProto)); } @@ -64,16 +69,21 @@ Message ClientKeyset::toProto() const { ServerKeyset ServerKeyset::fromProto(const Message &proto) { + return fromProto(proto.asReader()); +} + +ServerKeyset +ServerKeyset::fromProto(concreteprotocol::ServerKeyset::Reader reader) { auto output = ServerKeyset(); - for (auto bskProto : proto.asReader().getLweBootstrapKeys()) { + for (auto bskProto : reader.getLweBootstrapKeys()) { output.lweBootstrapKeys.push_back(LweBootstrapKey::fromProto(bskProto)); } - for (auto kskProto : proto.asReader().getLweKeyswitchKeys()) { + for (auto kskProto : reader.getLweKeyswitchKeys()) { output.lweKeyswitchKeys.push_back(LweKeyswitchKey::fromProto(kskProto)); } - for (auto pkskProto : proto.asReader().getPackingKeyswitchKeys()) { + for (auto pkskProto : reader.getPackingKeyswitchKeys()) { output.packingKeyswitchKeys.push_back( PackingKeyswitchKey::fromProto(pkskProto)); } @@ -117,38 +127,65 @@ Keyset::Keyset(const Message &info, client.lweSecretKeys.push_back(lweSk); } else { // generate new key - client.lweSecretKeys.push_back(LweSecretKey(keyInfo, secretCsprng)); + client.lweSecretKeys.push_back(LweSecretKey( + (Message)keyInfo, secretCsprng)); } } for (auto keyInfo : info.asReader().getLweBootstrapKeys()) { server.lweBootstrapKeys.push_back(LweBootstrapKey( - keyInfo, client.lweSecretKeys[keyInfo.getInputId()], + (Message)keyInfo, + client.lweSecretKeys[keyInfo.getInputId()], client.lweSecretKeys[keyInfo.getOutputId()], encryptionCsprng)); } for (auto keyInfo : info.asReader().getLweKeyswitchKeys()) { server.lweKeyswitchKeys.push_back(LweKeyswitchKey( - keyInfo, client.lweSecretKeys[keyInfo.getInputId()], + (Message)keyInfo, + client.lweSecretKeys[keyInfo.getInputId()], client.lweSecretKeys[keyInfo.getOutputId()], encryptionCsprng)); } for (auto keyInfo : info.asReader().getPackingKeyswitchKeys()) { server.packingKeyswitchKeys.push_back(PackingKeyswitchKey( - keyInfo, client.lweSecretKeys[keyInfo.getInputId()], + (Message)keyInfo, + client.lweSecretKeys[keyInfo.getInputId()], client.lweSecretKeys[keyInfo.getOutputId()], encryptionCsprng)); } } Keyset Keyset::fromProto(const Message &proto) { - auto server = ServerKeyset::fromProto(proto.asReader().getServer()); - auto client = ClientKeyset::fromProto(proto.asReader().getClient()); + return fromProto(proto.asReader()); +} + +Keyset Keyset::fromProto(concreteprotocol::Keyset::Reader reader) { + auto server = ServerKeyset::fromProto(reader.getServer()); + auto client = ClientKeyset::fromProto(reader.getClient()); return {server, client}; } Message Keyset::toProto() const { auto output = Message(); - auto serverProto = server.toProto(); + // we inlined call to server.toProto() to avoid a single big copy of the + // server keyset. With this, we only do copies of individual keys. + auto serverKeyset = output.asBuilder().initServer(); + serverKeyset.initLweBootstrapKeys(server.lweBootstrapKeys.size()); + for (size_t i = 0; i < server.lweBootstrapKeys.size(); i++) { + serverKeyset.getLweBootstrapKeys().setWithCaveats( + i, server.lweBootstrapKeys[i].toProto().asReader()); + } + + serverKeyset.initLweKeyswitchKeys(server.lweKeyswitchKeys.size()); + for (size_t i = 0; i < server.lweKeyswitchKeys.size(); i++) { + serverKeyset.getLweKeyswitchKeys().setWithCaveats( + i, server.lweKeyswitchKeys[i].toProto().asReader()); + } + + serverKeyset.initPackingKeyswitchKeys(server.packingKeyswitchKeys.size()); + for (size_t i = 0; i < server.packingKeyswitchKeys.size(); i++) { + serverKeyset.getPackingKeyswitchKeys().setWithCaveats( + i, server.packingKeyswitchKeys[i].toProto().asReader()); + } + // client serialization is not inlined as keys aren't that big auto clientProto = client.toProto(); - output.asBuilder().setServer(serverProto.asReader()); output.asBuilder().setClient(clientProto.asReader()); return output; } diff --git a/compilers/concrete-compiler/compiler/lib/Common/Protocol.cpp b/compilers/concrete-compiler/compiler/lib/Common/Protocol.cpp index f2e94f576e..e718f0be4c 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Protocol.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Protocol.cpp @@ -17,8 +17,13 @@ namespace protocol { /// dimensions. std::vector protoShapeToDimensions(const Message &shape) { + return protoShapeToDimensions(shape.asReader()); +} + +std::vector +protoShapeToDimensions(concreteprotocol::Shape::Reader reader) { auto output = std::vector(); - for (auto dim : shape.asReader().getDimensions()) { + for (auto dim : reader.getDimensions()) { output.push_back(dim); } return output; diff --git a/compilers/concrete-compiler/compiler/lib/Common/Transformers.cpp b/compilers/concrete-compiler/compiler/lib/Common/Transformers.cpp index ce121fbfb2..77b6c18ebb 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Transformers.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Transformers.cpp @@ -220,7 +220,7 @@ updateGateInfoAccordingValue(Message &gate, auto lweSize = gateCiphertext.getEncryption().getLweDimension() + 1; gateDimensions.set(gateDimensions.size() - 1, lweSize); concreteShapeDimensions.set(concreteShapeDimensions.size() - 1, lweSize); - return gateBuilder.asReader(); + return (Message)gateBuilder.asReader(); } if (gateCompression == concreteprotocol::Compression::NONE && valueCompression == concreteprotocol::Compression::SEED) { @@ -232,7 +232,7 @@ updateGateInfoAccordingValue(Message &gate, concreteprotocol::Compression::SEED); gateDimensions.set(gateDimensions.size() - 1, 3); concreteShapeDimensions.set(concreteShapeDimensions.size() - 1, 3); - return gateBuilder.asReader(); + return (Message)gateBuilder.asReader(); } return gate; } @@ -837,11 +837,13 @@ Result TransformerFactory::getLweCiphertextInputTransformer( .getEncoding() .hasInteger()) { OUTCOME_TRY(encodingTransformer, - getIntegerEncodingTransformer(gateInfo.asReader() - .getTypeInfo() - .getLweCiphertext() - .getEncoding() - .getInteger())); + getIntegerEncodingTransformer( + (Message) + gateInfo.asReader() + .getTypeInfo() + .getLweCiphertext() + .getEncoding() + .getInteger())); } else { return StringError("Malformed gate info"); } @@ -850,28 +852,35 @@ Result TransformerFactory::getLweCiphertextInputTransformer( Transformer encryptionTransformer; if (useSimulation) { OUTCOME_TRY(encryptionTransformer, - getEncryptionSimulationTransformer(gateInfo.asReader() - .getTypeInfo() - .getLweCiphertext() - .getEncryption(), - csprng)); + getEncryptionSimulationTransformer( + (Message) + gateInfo.asReader() + .getTypeInfo() + .getLweCiphertext() + .getEncryption(), + csprng)); } else { auto compression = gateInfo.asReader().getTypeInfo().getLweCiphertext().getCompression(); if (compression == concreteprotocol::Compression::NONE) { OUTCOME_TRY(encryptionTransformer, - getEncryptionTransformer(keyset, - gateInfo.asReader() - .getTypeInfo() - .getLweCiphertext() - .getEncryption(), - csprng)); + getEncryptionTransformer( + keyset, + (Message) + gateInfo.asReader() + .getTypeInfo() + .getLweCiphertext() + .getEncryption(), + csprng)); } else if (compression == concreteprotocol::Compression::SEED) { - OUTCOME_TRY(encryptionTransformer, - getSeededEncryptionTransformer(keyset, gateInfo.asReader() - .getTypeInfo() - .getLweCiphertext() - .getEncryption())); + OUTCOME_TRY( + encryptionTransformer, + getSeededEncryptionTransformer( + keyset, (Message) + gateInfo.asReader() + .getTypeInfo() + .getLweCiphertext() + .getEncryption())); } else { return StringError( "Only none compression is currently supported for lwe ciphertext " @@ -947,8 +956,10 @@ Result TransformerFactory::getLweCiphertextArgTransformer( /// Generating the decompression transformer. auto lweCiphertextInfo = gateInfo.asReader().getTypeInfo().getLweCiphertext(); OUTCOME_TRY(auto decompressionTransformer, - getDecompressionTransformer(lweCiphertextInfo.getEncryption(), - useSimulation)); + getDecompressionTransformer( + (Message) + lweCiphertextInfo.getEncryption(), + useSimulation)); // Generating the verifier. TransportValueVerifier verify; @@ -1021,7 +1032,11 @@ Result TransformerFactory::getLweCiphertextOutputTransformer( /// Generating the decompression transformer. auto encryptionInfo = - gateInfo.asReader().getTypeInfo().getLweCiphertext().getEncryption(); + (Message)gateInfo + .asReader() + .getTypeInfo() + .getLweCiphertext() + .getEncryption(); OUTCOME_TRY(auto decompressionTransformer, getDecompressionTransformer(encryptionInfo, useSimulation)); @@ -1031,10 +1046,7 @@ Result TransformerFactory::getLweCiphertextOutputTransformer( OUTCOME_TRY(decryptionTransformer, getDecryptionSimulationTransformer()); } else { OUTCOME_TRY(decryptionTransformer, - getDecryptionTransformer(keyset, gateInfo.asReader() - .getTypeInfo() - .getLweCiphertext() - .getEncryption())); + getDecryptionTransformer(keyset, encryptionInfo)); } /// Generating the decoding transformer. @@ -1051,11 +1063,13 @@ Result TransformerFactory::getLweCiphertextOutputTransformer( .getEncoding() .hasInteger()) { OUTCOME_TRY(decodingTransformer, - getIntegerDecodingTransformer(gateInfo.asReader() - .getTypeInfo() - .getLweCiphertext() - .getEncoding() - .getInteger())); + getIntegerDecodingTransformer( + (Message) + gateInfo.asReader() + .getTypeInfo() + .getLweCiphertext() + .getEncoding() + .getInteger())); } else { return StringError("Malformed gate info"); } diff --git a/compilers/concrete-compiler/compiler/lib/Common/Values.cpp b/compilers/concrete-compiler/compiler/lib/Common/Values.cpp index 43ee7974ba..0f2efeb53f 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Values.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Values.cpp @@ -173,12 +173,17 @@ size_t Value::getLength() const { bool Value::isCompatibleWithShape( const Message &shape) const { + return isCompatibleWithShape(shape.asReader()); +} + +bool Value::isCompatibleWithShape( + concreteprotocol::Shape::Reader reader) const { auto dimensions = getDimensions(); - if ((uint32_t)shape.asReader().getDimensions().size() != dimensions.size()) { + if ((uint32_t)reader.getDimensions().size() != dimensions.size()) { return false; } for (uint32_t i = 0; i < dimensions.size(); i++) { - if (shape.asReader().getDimensions()[i] != dimensions[i]) { + if (reader.getDimensions()[i] != dimensions[i]) { return false; } } diff --git a/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLib.cpp b/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLib.cpp index 5bba83f662..9c50d504fc 100644 --- a/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLib.cpp @@ -494,14 +494,17 @@ Result ServerCircuit::fromDynamicModule( ArgTransformer transformer; if (gateInfo.getTypeInfo().hasIndex()) { OUTCOME_TRY(transformer, - TransformerFactory::getIndexArgTransformer(gateInfo)); + TransformerFactory::getIndexArgTransformer( + (Message)gateInfo)); } else if (gateInfo.getTypeInfo().hasPlaintext()) { OUTCOME_TRY(transformer, - TransformerFactory::getPlaintextArgTransformer(gateInfo)); + TransformerFactory::getPlaintextArgTransformer( + (Message)gateInfo)); } else if (gateInfo.getTypeInfo().hasLweCiphertext()) { - OUTCOME_TRY(transformer, - TransformerFactory::getLweCiphertextArgTransformer( - gateInfo, useSimulation)); + OUTCOME_TRY( + transformer, + TransformerFactory::getLweCiphertextArgTransformer( + (Message)gateInfo, useSimulation)); } else { return StringError("Malformed input gate info."); } @@ -514,14 +517,17 @@ Result ServerCircuit::fromDynamicModule( ReturnTransformer transformer; if (gateInfo.getTypeInfo().hasIndex()) { OUTCOME_TRY(transformer, - TransformerFactory::getIndexReturnTransformer(gateInfo)); + TransformerFactory::getIndexReturnTransformer( + (Message)gateInfo)); } else if (gateInfo.getTypeInfo().hasPlaintext()) { OUTCOME_TRY(transformer, - TransformerFactory::getPlaintextReturnTransformer(gateInfo)); + TransformerFactory::getPlaintextReturnTransformer( + (Message)gateInfo)); } else if (gateInfo.getTypeInfo().hasLweCiphertext()) { - OUTCOME_TRY(transformer, - TransformerFactory::getLweCiphertextReturnTransformer( - gateInfo, useSimulation)); + OUTCOME_TRY( + transformer, + TransformerFactory::getLweCiphertextReturnTransformer( + (Message)gateInfo, useSimulation)); } else { return StringError("Malformed input gate info."); } @@ -535,14 +541,16 @@ Result ServerCircuit::fromDynamicModule( output.argRawSize = 0; for (auto gateInfo : circuitInfo.asReader().getInputs()) { - auto descriptorSize = getGateDescriptionSize(gateInfo, useSimulation); + auto descriptorSize = getGateDescriptionSize( + (Message)gateInfo, useSimulation); output.argDescriptorSizes.push_back(descriptorSize); output.argRawSize += descriptorSize; } output.returnRawSize = 0; for (auto gateInfo : circuitInfo.asReader().getOutputs()) { - auto descriptorSize = getGateDescriptionSize(gateInfo, useSimulation); + auto descriptorSize = getGateDescriptionSize( + (Message)gateInfo, useSimulation); output.returnDescriptorSizes.push_back(descriptorSize); output.returnRawSize += descriptorSize; } @@ -605,9 +613,12 @@ void ServerCircuit::invoke(const ServerKeyset &serverKeyset) { for (unsigned int i = 0; i < circuitInfo.asReader().getOutputs().size(); i++) { // We read the descriptor from the _returnRaws via the maps. - size_t precision = - getGateIntegerPrecision(circuitInfo.asReader().getOutputs()[i]); - bool isSigned = getGateIsSigned(circuitInfo.asReader().getOutputs()[i]); + size_t precision = getGateIntegerPrecision( + (Message)circuitInfo.asReader() + .getOutputs()[i]); + bool isSigned = getGateIsSigned( + (Message)circuitInfo.asReader() + .getOutputs()[i]); InvocationDescriptor descriptor = InvocationDescriptor::fromU64s(_returnRawMaps[i], precision, isSigned); // We generate a value from the descriptor which we store in the @@ -631,7 +642,8 @@ ServerProgram::load(const Message &programInfo, for (auto circuitInfo : programInfo.asReader().getCircuits()) { OUTCOME_TRY(auto serverCircuit, ServerCircuit::fromDynamicModule( - circuitInfo, sharedDynamicModule, useSimulation)); + (Message)circuitInfo, + sharedDynamicModule, useSimulation)); serverCircuits.push_back(serverCircuit); } output.serverCircuits = serverCircuits; diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp index dc8e6229a7..f6e2a13d65 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp @@ -18,17 +18,14 @@ namespace concretelang { void CircuitCompilationFeedback::fillFromCircuitInfo( concreteprotocol::CircuitInfo::Reader circuitInfo) { - auto computeGateSize = - [&](const Message &gateInfo) { - unsigned int nElements = 1; - for (auto dimension : - gateInfo.asReader().getRawInfo().getShape().getDimensions()) { - nElements *= dimension; - } - unsigned int gateScalarSize = - gateInfo.asReader().getRawInfo().getIntegerPrecision() / 8; - return nElements * gateScalarSize; - }; + auto computeGateSize = [&](const concreteprotocol::GateInfo::Reader reader) { + unsigned int nElements = 1; + for (auto dimension : reader.getRawInfo().getShape().getDimensions()) { + nElements *= dimension; + } + unsigned int gateScalarSize = reader.getRawInfo().getIntegerPrecision() / 8; + return nElements * gateScalarSize; + }; // Compute the size of inputs totalInputsSize = 0; for (auto gateInfo : circuitInfo.getInputs()) { diff --git a/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp b/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp index c2554b9abb..b6fb754e56 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp @@ -326,7 +326,9 @@ extractCircuitInfo(mlir::func::FuncOp funcOp, auto compression = compressInputCiphertexts ? concreteprotocol::Compression::SEED : concreteprotocol::Compression::NONE; - auto maybeGate = generateGate(ty, encoding, curve, compression); + auto maybeGate = + generateGate(ty, (Message)encoding, + curve, compression); if (!maybeGate) { return maybeGate.takeError(); } @@ -336,7 +338,9 @@ extractCircuitInfo(mlir::func::FuncOp funcOp, auto ty = funcType.getResult(i); auto encoding = encodings.getOutputs()[i]; auto compression = concreteprotocol::Compression::NONE; - auto maybeGate = generateGate(ty, encoding, curve, compression); + auto maybeGate = + generateGate(ty, (Message)encoding, + curve, compression); if (!maybeGate) { return maybeGate.takeError(); } diff --git a/frontends/concrete-python/concrete/fhe/compilation/keys.py b/frontends/concrete-python/concrete/fhe/compilation/keys.py index 76356ca7c6..9ef2a698c5 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/keys.py +++ b/frontends/concrete-python/concrete/fhe/compilation/keys.py @@ -115,7 +115,7 @@ def save(self, location: Union[str, Path]): message = f"Unable to save keys to {location} because it already exists" raise ValueError(message) - location.write_bytes(self.serialize()) + self.serialize_to_file(location) def load(self, location: Union[str, Path]): """ @@ -171,6 +171,8 @@ def serialize(self) -> bytes: Serialize keys into bytes. Serialized keys are not encrypted, so be careful how you store/transfer them! + `serialize_to_file` is supposed to be more performant as it avoid copying the buffer + between the Compiler and the Frontend. Returns: bytes: @@ -184,6 +186,23 @@ def serialize(self) -> bytes: serialized_keyset = self._keyset.serialize() return serialized_keyset + def serialize_to_file(self, path: Path): + """ + Serialize keys into a file. + + Serialized keys are not encrypted, so be careful how you store/transfer them! + This is supposed to be more performant than `serialize` as it avoid copying the buffer + between the Compiler and the Frontend. + + Args: + path (Path): where to save serialized keys + """ + if self._keyset is None: + message = "Keys cannot be serialized before they are generated" + raise RuntimeError(message) + + self._keyset.serialize_to_file(str(path)) + @staticmethod def deserialize(serialized_keys: Union[Path, bytes]) -> "Keys": """ diff --git a/frontends/concrete-python/tests/compilation/test_keys.py b/frontends/concrete-python/tests/compilation/test_keys.py index 94d56ef858..2a24868075 100644 --- a/frontends/concrete-python/tests/compilation/test_keys.py +++ b/frontends/concrete-python/tests/compilation/test_keys.py @@ -175,6 +175,13 @@ def f(x): expected_message = "Keys cannot be serialized before they are generated" helpers.check_str(expected_message, str(excinfo.value)) + with pytest.raises(RuntimeError) as excinfo: + # path doesn't matter as it will fail + circuit.keys.serialize_to_file(Path("_keys_file")) + + expected_message = "Keys cannot be serialized before they are generated" + helpers.check_str(expected_message, str(excinfo.value)) + def test_keys_generate_manual_seed(helpers): """