From 88e90a775fd05e689f55620c1e5440210909fc3a Mon Sep 17 00:00:00 2001 From: Bourgerie Quentin Date: Tue, 14 May 2024 09:09:56 +0200 Subject: [PATCH] fix(compiler): Share the decompress state instead of try to avoid copies This partially reverts commit 081e8b7b74d31909a3aa9a59099004a1ceeeae3f. --- .../compiler/include/concretelang/Common/Keys.h | 12 ++++++------ .../compiler/include/concretelang/Runtime/context.h | 6 ++++-- .../include/concretelang/ServerLib/ServerLib.h | 4 ++-- .../lib/Bindings/Python/CompilerAPIModule.cpp | 2 +- .../concrete-compiler/compiler/lib/Common/Keys.cpp | 12 ++++++------ .../compiler/lib/Runtime/context.cpp | 2 +- .../compiler/lib/ServerLib/ServerLib.cpp | 4 ++-- 7 files changed, 22 insertions(+), 20 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h index 59a38c3b27..1f35c6f373 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h @@ -89,7 +89,7 @@ class LweBootstrapKey { Message info) : seededBuffer(std::make_shared>()), buffer(buffer), info(info), decompress_mutext(std::make_shared()), - decompressed(false){}; + decompressed(std::make_shared(false)){}; /// @brief Initialize the key from the protocol message. static LweBootstrapKey @@ -111,7 +111,7 @@ class LweBootstrapKey { : seededBuffer(std::make_shared>()), buffer(std::make_shared>()), info(info), decompress_mutext(std::make_shared()), - decompressed(false){}; + decompressed(std::make_shared(false)){}; LweBootstrapKey() = delete; /// @brief The buffer of the seeded key if needed. @@ -127,7 +127,7 @@ class LweBootstrapKey { std::shared_ptr decompress_mutext; /// @brief A boolean that indicates if the decompression is done or not - bool decompressed; + std::shared_ptr decompressed; }; class LweKeyswitchKey { @@ -141,7 +141,7 @@ class LweKeyswitchKey { Message info) : seededBuffer(std::make_shared>()), buffer(buffer), info(info), decompress_mutext(std::make_shared()), - decompressed(false){}; + decompressed(std::make_shared(false)){}; /// @brief Initialize the key from the protocol message. static LweKeyswitchKey @@ -163,7 +163,7 @@ class LweKeyswitchKey { : seededBuffer(std::make_shared>()), buffer(std::make_shared>()), info(info), decompress_mutext(std::make_shared()), - decompressed(false){}; + decompressed(std::make_shared(false)){}; /// @brief The buffer of the seeded key if needed. std::shared_ptr> seededBuffer; @@ -178,7 +178,7 @@ class LweKeyswitchKey { std::shared_ptr decompress_mutext; /// @brief A boolean that indicates if the decompression is done or not - bool decompressed; + std::shared_ptr decompressed; }; class PackingKeyswitchKey { diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/context.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/context.h index 2fdc67d5a5..18cf4bae37 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/context.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/context.h @@ -42,7 +42,7 @@ typedef struct FFT { typedef struct RuntimeContext { RuntimeContext() = delete; - RuntimeContext(ServerKeyset &serverKeyset); + RuntimeContext(ServerKeyset serverKeyset); virtual ~RuntimeContext() { #ifdef CONCRETELANG_CUDA_SUPPORT for (int i = 0; i < num_devices; ++i) { @@ -71,8 +71,10 @@ typedef struct RuntimeContext { virtual const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; } + const ServerKeyset getKeys() const { return serverKeyset; } + protected: - ServerKeyset &serverKeyset; + ServerKeyset serverKeyset; std::vector>>> fourier_bootstrap_keys; std::vector ffts; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h index 3706f5b320..476b0bae70 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h @@ -47,7 +47,7 @@ class ServerCircuit { public: /// Call the circuit with public arguments. - Result> call(ServerKeyset &serverKeyset, + Result> call(const ServerKeyset &serverKeyset, std::vector &args); /// Simulate the circuit with public arguments. @@ -65,7 +65,7 @@ class ServerCircuit { std::shared_ptr dynamicModule, bool useSimulation); - void invoke(ServerKeyset &serverKeyset); + void invoke(const ServerKeyset &serverKeyset); Message circuitInfo; bool useSimulation; diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index b5c21ae753..f5d9734756 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1226,7 +1226,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( ::concretelang::clientlib::PublicArguments &publicArguments, ::concretelang::clientlib::EvaluationKeys &evaluationKeys) { SignalGuard signalGuard; - auto &keyset = evaluationKeys.keyset; + auto keyset = evaluationKeys.keyset; auto values = publicArguments.values; GET_OR_THROW_RESULT(auto output, circuit.call(keyset, values)); ::concretelang::clientlib::PublicResult res{output}; diff --git a/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp b/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp index 4cf2bdebe2..d9889dbf0e 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp @@ -189,10 +189,10 @@ void LweBootstrapKey::decompress() { case concreteprotocol::Compression::NONE: return; case concreteprotocol::Compression::SEED: { - if (decompressed) + if (*decompressed) return; const std::lock_guard guard(*decompress_mutext); - if (decompressed) + if (*decompressed) return; auto params = info.asReader().getParams(); buffer->resize(concrete_cpu_bootstrap_key_size_u64( @@ -204,7 +204,7 @@ void LweBootstrapKey::decompress() { buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(), params.getPolynomialSize(), params.getGlweDimension(), params.getLevelCount(), params.getBaseLog(), seed, Parallelism::Rayon); - decompressed = true; + *decompressed = true; return; } default: @@ -309,10 +309,10 @@ void LweKeyswitchKey::decompress() { case concreteprotocol::Compression::NONE: return; case concreteprotocol::Compression::SEED: { - if (decompressed) + if (*decompressed) return; const std::lock_guard guard(*decompress_mutext); - if (decompressed) + if (*decompressed) return; auto params = info.asReader().getParams(); buffer->resize(concrete_cpu_keyswitch_key_size_u64( @@ -324,7 +324,7 @@ void LweKeyswitchKey::decompress() { buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(), params.getOutputLweDimension(), params.getLevelCount(), params.getBaseLog(), seed, Parallelism::Rayon); - decompressed = true; + *decompressed = true; return; } default: diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/context.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/context.cpp index 479a6aa63f..9d0df8d0d0 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/context.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/context.cpp @@ -29,7 +29,7 @@ FFT::~FFT() { } } -RuntimeContext::RuntimeContext(ServerKeyset &serverKeyset) +RuntimeContext::RuntimeContext(ServerKeyset serverKeyset) : serverKeyset(serverKeyset) { // Initialize for each bootstrap key the fourier one diff --git a/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLib.cpp b/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLib.cpp index 9140c46f0c..6a750742f1 100644 --- a/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLib.cpp @@ -428,7 +428,7 @@ bool getGateIsSigned(const Message &gateInfo) { } Result> -ServerCircuit::call(ServerKeyset &serverKeyset, +ServerCircuit::call(const ServerKeyset &serverKeyset, std::vector &args) { if (args.size() != argsBuffer.size()) { return StringError("Called circuit with wrong number of arguments"); @@ -542,7 +542,7 @@ Result ServerCircuit::fromDynamicModule( return output; } -void ServerCircuit::invoke(ServerKeyset &serverKeyset) { +void ServerCircuit::invoke(const ServerKeyset &serverKeyset) { // We create a runtime context from the keyset, and place a pointer to it in // the structure.