Skip to content

Commit

Permalink
feat(optimizer): add generic keyset info generation
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed Nov 21, 2024
1 parent 7ffb556 commit c6d7fd4
Show file tree
Hide file tree
Showing 10 changed files with 506 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef CONCRETELANG_COMMON_KEYSETS_H
#define CONCRETELANG_COMMON_KEYSETS_H

#include "concrete-optimizer.hpp"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Error.h"
Expand Down Expand Up @@ -92,6 +93,10 @@ class KeysetCache {
KeysetCache() = default;
};

Message<concreteprotocol::KeysetInfo> generate_generic_keyset_info(
std::vector<concrete_optimizer::utils::PartitionDefinition> partitions,
bool generate_fks);

} // namespace keysets
} // namespace concretelang

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "concretelang/Support/Error.h"
#include "concretelang/Support/V0Parameters.h"
#include "concretelang/Support/logging.h"
#include <cstdint>
#include <filesystem>
#include <memory>
#include <mlir-c/Bindings/Python/Interop.h>
Expand Down Expand Up @@ -645,6 +646,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(LweSecretKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(LweSecretKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<LweSecretKeyParam>(m, "LweSecretKeyParam")
.def(
Expand All @@ -659,6 +672,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of an LWE Secret Key.";

// ------------------------------------------------------------------------------//
Expand Down Expand Up @@ -689,6 +704,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(BootstrapKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(BootstrapKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<BootstrapKeyParam>(m, "BootstrapKeyParam")
.def(
Expand Down Expand Up @@ -745,6 +772,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a Bootstrap key.";

// ------------------------------------------------------------------------------//
Expand All @@ -766,6 +795,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(KeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(KeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<KeyswitchKeyParam>(m, "KeyswitchKeyParam")
.def(
Expand Down Expand Up @@ -804,6 +845,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a keyswitch key.";

// ------------------------------------------------------------------------------//
Expand Down Expand Up @@ -834,6 +877,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(PackingKeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(PackingKeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<PackingKeyswitchKeyParam>(m, "PackingKeyswitchKeyParam")
.def(
Expand Down Expand Up @@ -892,13 +947,44 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a packing keyswitch key.";

// ------------------------------------------------------------------------------//
// PARTITION DEFINITION //
// ------------------------------------------------------------------------------//
//
pybind11::class_<concrete_optimizer::utils::PartitionDefinition>(
m, "PartitionDefinition")
.def(init([](uint8_t precision, double norm2)
-> concrete_optimizer::utils::PartitionDefinition {
return concrete_optimizer::utils::PartitionDefinition{precision,
norm2};
}),
arg("precision"), arg("norm2"))
.doc() = "Definition of a partition (in terms of precision in bits and "
"norm2 in value).";

// ------------------------------------------------------------------------------//
// KEYSET INFO //
// ------------------------------------------------------------------------------//
typedef Message<concreteprotocol::KeysetInfo> KeysetInfo;
pybind11::class_<KeysetInfo>(m, "KeysetInfo")
.def_static(
"generate_generic",
[](std::vector<concrete_optimizer::utils::PartitionDefinition>
partitions,
bool generateFks) -> KeysetInfo {
if (partitions.size() < 2) {
throw std::runtime_error("Need at least two partition defs to "
"generate a generic keyset info.");
}
return ::concretelang::keysets::generate_generic_keyset_info(
partitions, generateFks);
},
arg("partition_defs"), arg("generate_fks"),
"Generate a generic keyset info for a set of partition definitions")
.def(
"secret_keys",
[](KeysetInfo &keysetInfo) {
Expand Down
94 changes: 94 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "concretelang/Common/Keysets.h"
#include "capnp/message.h"
#include "concrete-cpu.h"
#include "concrete-optimizer.hpp"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Error.h"
Expand Down Expand Up @@ -417,5 +418,98 @@ KeysetCache::getKeyset(const Message<concreteprotocol::KeysetInfo> &keysetInfo,
return std::move(keyset);
}

Message<concreteprotocol::KeysetInfo> generate_generic_keyset_info(
std::vector<concrete_optimizer::utils::PartitionDefinition> partitionDefs,
bool generateFks) {
auto output = Message<concreteprotocol::KeysetInfo>{};
rust::Vec<concrete_optimizer::utils::PartitionDefinition> rustPartitionDefs{};
for (auto def : partitionDefs) {
rustPartitionDefs.push_back(def);
}
auto parameters = concrete_optimizer::utils::generate_generic_keyset_info(
rustPartitionDefs, generateFks);

auto skLen = (int)parameters.secret_keys.size();
auto skBuilder = output.asBuilder().initLweSecretKeys(skLen);
for (int i = 0; i < skLen; i++) {
auto output = Message<concreteprotocol::LweSecretKeyInfo>();
auto sk = parameters.secret_keys[i];
output.asBuilder().setId(sk.identifier);
output.asBuilder().getParams().setIntegerPrecision(64);
output.asBuilder().getParams().setLweDimension(sk.polynomial_size *
sk.glwe_dimension);
output.asBuilder().getParams().setKeyType(
::concreteprotocol::KeyType::BINARY);
skBuilder.setWithCaveats(i, output.asReader());
}

auto bskLen = (int)parameters.bootstrap_keys.size();
auto bskBuilder = output.asBuilder().initLweBootstrapKeys(bskLen);
for (int i = 0; i < bskLen; i++) {
auto output = Message<concreteprotocol::LweBootstrapKeyInfo>();
auto bsk = parameters.bootstrap_keys[i];
output.asBuilder().setId(bsk.identifier);
output.asBuilder().setInputId(bsk.input_key.identifier);
output.asBuilder().setOutputId(bsk.output_key.identifier);
output.asBuilder().getParams().setLevelCount(
bsk.br_decomposition_parameter.level);
output.asBuilder().getParams().setBaseLog(
bsk.br_decomposition_parameter.log2_base);
output.asBuilder().getParams().setGlweDimension(
bsk.output_key.glwe_dimension);
output.asBuilder().getParams().setPolynomialSize(
bsk.output_key.polynomial_size);
output.asBuilder().getParams().setInputLweDimension(
bsk.input_key.polynomial_size);
output.asBuilder().getParams().setIntegerPrecision(64);
output.asBuilder().getParams().setKeyType(
concreteprotocol::KeyType::BINARY);
bskBuilder.setWithCaveats(i, output.asReader());
}

auto kskLen = (int)parameters.keyswitch_keys.size();
auto ckskLen = (int)parameters.conversion_keyswitch_keys.size();
auto kskBuilder = output.asBuilder().initLweKeyswitchKeys(kskLen + ckskLen);
for (int i = 0; i < kskLen; i++) {
auto output = Message<concreteprotocol::LweKeyswitchKeyInfo>();
auto ksk = parameters.keyswitch_keys[i];
output.asBuilder().setId(ksk.identifier);
output.asBuilder().setInputId(ksk.input_key.identifier);
output.asBuilder().setOutputId(ksk.output_key.identifier);
output.asBuilder().getParams().setLevelCount(
ksk.ks_decomposition_parameter.level);
output.asBuilder().getParams().setBaseLog(
ksk.ks_decomposition_parameter.log2_base);
output.asBuilder().getParams().setIntegerPrecision(64);
output.asBuilder().getParams().setInputLweDimension(
ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size);
output.asBuilder().getParams().setOutputLweDimension(
ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size);
output.asBuilder().getParams().setKeyType(
concreteprotocol::KeyType::BINARY);
kskBuilder.setWithCaveats(i, output.asReader());
}
for (int i = 0; i < ckskLen; i++) {
auto output = Message<concreteprotocol::LweKeyswitchKeyInfo>();
auto ksk = parameters.conversion_keyswitch_keys[i];
output.asBuilder().setId(ksk.identifier);
output.asBuilder().setInputId(ksk.input_key.identifier);
output.asBuilder().setOutputId(ksk.output_key.identifier);
output.asBuilder().getParams().setLevelCount(
ksk.ks_decomposition_parameter.level);
output.asBuilder().getParams().setBaseLog(
ksk.ks_decomposition_parameter.log2_base);
output.asBuilder().getParams().setIntegerPrecision(64);
output.asBuilder().getParams().setInputLweDimension(
ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size);
output.asBuilder().getParams().setOutputLweDimension(
ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size);
output.asBuilder().getParams().setKeyType(
concreteprotocol::KeyType::BINARY);
kskBuilder.setWithCaveats(i + kskLen, output.asReader());
}
return output;
}

} // namespace keysets
} // namespace concretelang
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use concrete_optimizer::dag::operator::{
};
use concrete_optimizer::dag::unparametrized;
use concrete_optimizer::optimization::config::{Config, SearchSpace};
use concrete_optimizer::optimization::dag::multi_parameters::generic_generation::generate_generic_parameters;
use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::CircuitSolution;
use concrete_optimizer::optimization::dag::multi_parameters::optimize::{
KeysetRestriction, MacroParameters, NoSearchSpaceRestriction, RangeRestriction,
Expand Down Expand Up @@ -913,6 +914,22 @@ fn location_from_string(string: &str) -> Box<Location> {
}
}

fn generate_generic_keyset_info(
inputs: Vec<ffi::PartitionDefinition>,
generate_fks: bool,
) -> ffi::CircuitKeys {
generate_generic_parameters(
inputs
.into_iter()
.map(
|ffi::PartitionDefinition { precision, norm2 }| concrete_optimizer::optimization::dag::multi_parameters::generic_generation::PartitionDefinition { precision, norm2 },
)
.collect(),
generate_fks,
)
.into()
}

pub struct Weights(operator::Weights);

fn vector(weights: &[i64]) -> Box<Weights> {
Expand Down Expand Up @@ -981,6 +998,12 @@ mod ffi {
#[namespace = "concrete_optimizer::utils"]
fn location_from_string(string: &str) -> Box<Location>;

#[namespace = "concrete_optimizer::utils"]
fn generate_generic_keyset_info(
partitions: Vec<PartitionDefinition>,
generate_fks: bool,
) -> CircuitKeys;

#[namespace = "concrete_optimizer::utils"]
fn get_external_partition(
name: String,
Expand Down Expand Up @@ -1359,6 +1382,13 @@ mod ffi {
pub struct KeysetRestriction {
pub info: KeysetInfo,
}

#[namespace = "concrete_optimizer::utils"]
#[derive(Debug, Clone)]
pub struct PartitionDefinition {
pub precision: u8,
pub norm2: f64,
}
}

fn processing_unit(options: &ffi::Options) -> ProcessingUnit {
Expand Down
Loading

0 comments on commit c6d7fd4

Please sign in to comment.