Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(compiler): Share the decompress state instead of try to avoid copies #822

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class LweBootstrapKey {
Message<concreteprotocol::LweBootstrapKeyInfo> info)
: seededBuffer(std::make_shared<std::vector<uint64_t>>()), buffer(buffer),
info(info), decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
decompressed(std::make_shared<bool>(false)){};

/// @brief Initialize the key from the protocol message.
static LweBootstrapKey
Expand All @@ -111,7 +111,7 @@ class LweBootstrapKey {
: seededBuffer(std::make_shared<std::vector<uint64_t>>()),
buffer(std::make_shared<std::vector<uint64_t>>()), info(info),
decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
decompressed(std::make_shared<bool>(false)){};
LweBootstrapKey() = delete;

/// @brief The buffer of the seeded key if needed.
Expand All @@ -127,7 +127,7 @@ class LweBootstrapKey {
std::shared_ptr<std::mutex> decompress_mutext;

/// @brief A boolean that indicates if the decompression is done or not
bool decompressed;
std::shared_ptr<bool> decompressed;
};

class LweKeyswitchKey {
Expand All @@ -141,7 +141,7 @@ class LweKeyswitchKey {
Message<concreteprotocol::LweKeyswitchKeyInfo> info)
: seededBuffer(std::make_shared<std::vector<uint64_t>>()), buffer(buffer),
info(info), decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
decompressed(std::make_shared<bool>(false)){};

/// @brief Initialize the key from the protocol message.
static LweKeyswitchKey
Expand All @@ -163,7 +163,7 @@ class LweKeyswitchKey {
: seededBuffer(std::make_shared<std::vector<uint64_t>>()),
buffer(std::make_shared<std::vector<uint64_t>>()), info(info),
decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
decompressed(std::make_shared<bool>(false)){};

/// @brief The buffer of the seeded key if needed.
std::shared_ptr<std::vector<uint64_t>> seededBuffer;
Expand All @@ -178,7 +178,7 @@ class LweKeyswitchKey {
std::shared_ptr<std::mutex> decompress_mutext;

/// @brief A boolean that indicates if the decompression is done or not
bool decompressed;
std::shared_ptr<bool> decompressed;
};

class PackingKeyswitchKey {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ typedef struct FFT {
typedef struct RuntimeContext {

RuntimeContext() = delete;
RuntimeContext(ServerKeyset &serverKeyset);
RuntimeContext(ServerKeyset serverKeyset);
virtual ~RuntimeContext() {
#ifdef CONCRETELANG_CUDA_SUPPORT
for (int i = 0; i < num_devices; ++i) {
Expand Down Expand Up @@ -71,8 +71,10 @@ typedef struct RuntimeContext {

virtual const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }

const ServerKeyset getKeys() const { return serverKeyset; }

protected:
ServerKeyset &serverKeyset;
ServerKeyset serverKeyset;
std::vector<std::shared_ptr<std::vector<std::complex<double>>>>
fourier_bootstrap_keys;
std::vector<FFT> ffts;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ServerCircuit {

public:
/// Call the circuit with public arguments.
Result<std::vector<TransportValue>> call(ServerKeyset &serverKeyset,
Result<std::vector<TransportValue>> call(const ServerKeyset &serverKeyset,
std::vector<TransportValue> &args);

/// Simulate the circuit with public arguments.
Expand All @@ -65,7 +65,7 @@ class ServerCircuit {
std::shared_ptr<DynamicModule> dynamicModule,
bool useSimulation);

void invoke(ServerKeyset &serverKeyset);
void invoke(const ServerKeyset &serverKeyset);

Message<concreteprotocol::CircuitInfo> circuitInfo;
bool useSimulation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
::concretelang::clientlib::PublicArguments &publicArguments,
::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
SignalGuard signalGuard;
auto &keyset = evaluationKeys.keyset;
auto keyset = evaluationKeys.keyset;
auto values = publicArguments.values;
GET_OR_THROW_RESULT(auto output, circuit.call(keyset, values));
::concretelang::clientlib::PublicResult res{output};
Expand Down
12 changes: 6 additions & 6 deletions compilers/concrete-compiler/compiler/lib/Common/Keys.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ void LweBootstrapKey::decompress() {
case concreteprotocol::Compression::NONE:
return;
case concreteprotocol::Compression::SEED: {
if (decompressed)
if (*decompressed)
return;
const std::lock_guard<std::mutex> guard(*decompress_mutext);
if (decompressed)
if (*decompressed)
return;
auto params = info.asReader().getParams();
buffer->resize(concrete_cpu_bootstrap_key_size_u64(
Expand All @@ -204,7 +204,7 @@ void LweBootstrapKey::decompress() {
buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(),
params.getPolynomialSize(), params.getGlweDimension(),
params.getLevelCount(), params.getBaseLog(), seed, Parallelism::Rayon);
decompressed = true;
*decompressed = true;
return;
}
default:
Expand Down Expand Up @@ -309,10 +309,10 @@ void LweKeyswitchKey::decompress() {
case concreteprotocol::Compression::NONE:
return;
case concreteprotocol::Compression::SEED: {
if (decompressed)
if (*decompressed)
return;
const std::lock_guard<std::mutex> guard(*decompress_mutext);
if (decompressed)
if (*decompressed)
return;
auto params = info.asReader().getParams();
buffer->resize(concrete_cpu_keyswitch_key_size_u64(
Expand All @@ -324,7 +324,7 @@ void LweKeyswitchKey::decompress() {
buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(),
params.getOutputLweDimension(), params.getLevelCount(),
params.getBaseLog(), seed, Parallelism::Rayon);
decompressed = true;
*decompressed = true;
return;
}
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ FFT::~FFT() {
}
}

RuntimeContext::RuntimeContext(ServerKeyset &serverKeyset)
RuntimeContext::RuntimeContext(ServerKeyset serverKeyset)
: serverKeyset(serverKeyset) {

// Initialize for each bootstrap key the fourier one
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ bool getGateIsSigned(const Message<concreteprotocol::GateInfo> &gateInfo) {
}

Result<std::vector<TransportValue>>
ServerCircuit::call(ServerKeyset &serverKeyset,
ServerCircuit::call(const ServerKeyset &serverKeyset,
std::vector<TransportValue> &args) {
if (args.size() != argsBuffer.size()) {
return StringError("Called circuit with wrong number of arguments");
Expand Down Expand Up @@ -542,7 +542,7 @@ Result<ServerCircuit> ServerCircuit::fromDynamicModule(
return output;
}

void ServerCircuit::invoke(ServerKeyset &serverKeyset) {
void ServerCircuit::invoke(const ServerKeyset &serverKeyset) {

// We create a runtime context from the keyset, and place a pointer to it in
// the structure.
Expand Down
Loading