Skip to content

Commit

Permalink
fix(compiler): Share the decompress state instead of try to avoid copies
Browse files Browse the repository at this point in the history
This partially reverts commit 081e8b7.
  • Loading branch information
BourgerieQuentin committed May 14, 2024
1 parent d5a1bbd commit 88e90a7
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class LweBootstrapKey {
Message<concreteprotocol::LweBootstrapKeyInfo> info)
: seededBuffer(std::make_shared<std::vector<uint64_t>>()), buffer(buffer),
info(info), decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
decompressed(std::make_shared<bool>(false)){};

/// @brief Initialize the key from the protocol message.
static LweBootstrapKey
Expand All @@ -111,7 +111,7 @@ class LweBootstrapKey {
: seededBuffer(std::make_shared<std::vector<uint64_t>>()),
buffer(std::make_shared<std::vector<uint64_t>>()), info(info),
decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
decompressed(std::make_shared<bool>(false)){};
LweBootstrapKey() = delete;

/// @brief The buffer of the seeded key if needed.
Expand All @@ -127,7 +127,7 @@ class LweBootstrapKey {
std::shared_ptr<std::mutex> decompress_mutext;

/// @brief A boolean that indicates if the decompression is done or not
bool decompressed;
std::shared_ptr<bool> decompressed;
};

class LweKeyswitchKey {
Expand All @@ -141,7 +141,7 @@ class LweKeyswitchKey {
Message<concreteprotocol::LweKeyswitchKeyInfo> info)
: seededBuffer(std::make_shared<std::vector<uint64_t>>()), buffer(buffer),
info(info), decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
decompressed(std::make_shared<bool>(false)){};

/// @brief Initialize the key from the protocol message.
static LweKeyswitchKey
Expand All @@ -163,7 +163,7 @@ class LweKeyswitchKey {
: seededBuffer(std::make_shared<std::vector<uint64_t>>()),
buffer(std::make_shared<std::vector<uint64_t>>()), info(info),
decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
decompressed(std::make_shared<bool>(false)){};

/// @brief The buffer of the seeded key if needed.
std::shared_ptr<std::vector<uint64_t>> seededBuffer;
Expand All @@ -178,7 +178,7 @@ class LweKeyswitchKey {
std::shared_ptr<std::mutex> decompress_mutext;

/// @brief A boolean that indicates if the decompression is done or not
bool decompressed;
std::shared_ptr<bool> decompressed;
};

class PackingKeyswitchKey {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ typedef struct FFT {
typedef struct RuntimeContext {

RuntimeContext() = delete;
RuntimeContext(ServerKeyset &serverKeyset);
RuntimeContext(ServerKeyset serverKeyset);
virtual ~RuntimeContext() {
#ifdef CONCRETELANG_CUDA_SUPPORT
for (int i = 0; i < num_devices; ++i) {
Expand Down Expand Up @@ -71,8 +71,10 @@ typedef struct RuntimeContext {

virtual const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }

const ServerKeyset getKeys() const { return serverKeyset; }

protected:
ServerKeyset &serverKeyset;
ServerKeyset serverKeyset;
std::vector<std::shared_ptr<std::vector<std::complex<double>>>>
fourier_bootstrap_keys;
std::vector<FFT> ffts;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ServerCircuit {

public:
/// Call the circuit with public arguments.
Result<std::vector<TransportValue>> call(ServerKeyset &serverKeyset,
Result<std::vector<TransportValue>> call(const ServerKeyset &serverKeyset,
std::vector<TransportValue> &args);

/// Simulate the circuit with public arguments.
Expand All @@ -65,7 +65,7 @@ class ServerCircuit {
std::shared_ptr<DynamicModule> dynamicModule,
bool useSimulation);

void invoke(ServerKeyset &serverKeyset);
void invoke(const ServerKeyset &serverKeyset);

Message<concreteprotocol::CircuitInfo> circuitInfo;
bool useSimulation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
::concretelang::clientlib::PublicArguments &publicArguments,
::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
SignalGuard signalGuard;
auto &keyset = evaluationKeys.keyset;
auto keyset = evaluationKeys.keyset;
auto values = publicArguments.values;
GET_OR_THROW_RESULT(auto output, circuit.call(keyset, values));
::concretelang::clientlib::PublicResult res{output};
Expand Down
12 changes: 6 additions & 6 deletions compilers/concrete-compiler/compiler/lib/Common/Keys.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ void LweBootstrapKey::decompress() {
case concreteprotocol::Compression::NONE:
return;
case concreteprotocol::Compression::SEED: {
if (decompressed)
if (*decompressed)
return;
const std::lock_guard<std::mutex> guard(*decompress_mutext);
if (decompressed)
if (*decompressed)
return;
auto params = info.asReader().getParams();
buffer->resize(concrete_cpu_bootstrap_key_size_u64(
Expand All @@ -204,7 +204,7 @@ void LweBootstrapKey::decompress() {
buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(),
params.getPolynomialSize(), params.getGlweDimension(),
params.getLevelCount(), params.getBaseLog(), seed, Parallelism::Rayon);
decompressed = true;
*decompressed = true;
return;
}
default:
Expand Down Expand Up @@ -309,10 +309,10 @@ void LweKeyswitchKey::decompress() {
case concreteprotocol::Compression::NONE:
return;
case concreteprotocol::Compression::SEED: {
if (decompressed)
if (*decompressed)
return;
const std::lock_guard<std::mutex> guard(*decompress_mutext);
if (decompressed)
if (*decompressed)
return;
auto params = info.asReader().getParams();
buffer->resize(concrete_cpu_keyswitch_key_size_u64(
Expand All @@ -324,7 +324,7 @@ void LweKeyswitchKey::decompress() {
buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(),
params.getOutputLweDimension(), params.getLevelCount(),
params.getBaseLog(), seed, Parallelism::Rayon);
decompressed = true;
*decompressed = true;
return;
}
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ FFT::~FFT() {
}
}

RuntimeContext::RuntimeContext(ServerKeyset &serverKeyset)
RuntimeContext::RuntimeContext(ServerKeyset serverKeyset)
: serverKeyset(serverKeyset) {

// Initialize for each bootstrap key the fourier one
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ bool getGateIsSigned(const Message<concreteprotocol::GateInfo> &gateInfo) {
}

Result<std::vector<TransportValue>>
ServerCircuit::call(ServerKeyset &serverKeyset,
ServerCircuit::call(const ServerKeyset &serverKeyset,
std::vector<TransportValue> &args) {
if (args.size() != argsBuffer.size()) {
return StringError("Called circuit with wrong number of arguments");
Expand Down Expand Up @@ -542,7 +542,7 @@ Result<ServerCircuit> ServerCircuit::fromDynamicModule(
return output;
}

void ServerCircuit::invoke(ServerKeyset &serverKeyset) {
void ServerCircuit::invoke(const ServerKeyset &serverKeyset) {

// We create a runtime context from the keyset, and place a pointer to it in
// the structure.
Expand Down

0 comments on commit 88e90a7

Please sign in to comment.