From 690df726a559ffb9329d012b91c87da51d9da866 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Mon, 28 Oct 2024 11:47:08 +0100 Subject: [PATCH] feat(compiler): search space restriction --- .../include/concretelang/Common/Protocol.h | 12 + .../concretelang/Support/V0Parameters.h | 14 +- .../lib/Bindings/Python/CompilerAPIModule.cpp | 317 ++++++------ .../concrete/compiler/compilation_feedback.py | 14 +- .../compiler/lib/Support/V0Parameters.cpp | 12 +- .../src/concrete-optimizer.rs | 486 ++++++++--------- .../src/cpp/concrete-optimizer.cpp | 280 ++++++++-- .../src/cpp/concrete-optimizer.hpp | 117 +++-- .../concrete-optimizer-cpp/tests/src/main.cpp | 31 +- .../src/optimization/config.rs | 49 +- .../optimization/dag/multi_parameters/mod.rs | 2 + .../dag/multi_parameters/optimize/mod.rs | 156 ++++-- .../multi_parameters/optimize/restriction.rs | 487 ++++++++++++++++++ .../dag/multi_parameters/optimize/tests.rs | 45 +- .../dag/multi_parameters/optimize_generic.rs | 13 +- .../src/optimization/dag/solo_key/optimize.rs | 7 +- .../concrete/fhe/compilation/configuration.py | 9 + .../concrete/fhe/compilation/server.py | 5 + .../concrete/fhe/tfhers/bridge.py | 2 +- .../tests/compilation/test_restrictions.py | 98 ++++ 20 files changed, 1563 insertions(+), 593 deletions(-) create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/restriction.rs create mode 100644 frontends/concrete-python/tests/compilation/test_restrictions.py diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h index d4695f4c06..bf11a32a33 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h @@ -114,6 +114,18 @@ template struct Message { return *this; } + bool operator==(Message const &other) const { + capnp::AnyStruct::Reader left = this->asReader(); + capnp::AnyStruct::Reader right = other.asReader(); + return left == right; + } + + bool operator!=(Message const &other) const { + capnp::AnyStruct::Reader left = this->asReader(); + capnp::AnyStruct::Reader right = other.asReader(); + return left != right; + } + Message(Message &&input) : message(nullptr) { regionBuilder = input.regionBuilder; message = input.message; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h index 3ef9b19e9b..ec685495ac 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h @@ -6,6 +6,8 @@ #ifndef CONCRETELANG_SUPPORT_V0Parameter_H_ #define CONCRETELANG_SUPPORT_V0Parameter_H_ +#include +#include #include #include "llvm/ADT/Optional.h" @@ -80,6 +82,10 @@ const concrete_optimizer::Encoding DEFAULT_ENCODING = const bool DEFAULT_CACHE_ON_DISK = true; const uint32_t DEFAULT_CIPHERTEXT_MODULUS_LOG = 64; const uint32_t DEFAULT_FFT_PRECISION = 53; +const std::shared_ptr + DEFAULT_RANGE_RESTRICTION = {}; +const std::shared_ptr + DEFAULT_KEYSET_RESTRICTION = {}; /// The strategy of the crypto optimization enum Strategy { @@ -124,7 +130,10 @@ struct Config { bool cache_on_disk; uint32_t ciphertext_modulus_log; uint32_t fft_precision; - concrete_optimizer::ParameterRestrictions parameter_restrictions; + std::shared_ptr + range_restriction; + std::shared_ptr + keyset_restriction; std::vector composition_rules; bool composable; }; @@ -142,7 +151,8 @@ const Config DEFAULT_CONFIG = {UNSPECIFIED_P_ERROR, DEFAULT_CACHE_ON_DISK, DEFAULT_CIPHERTEXT_MODULUS_LOG, DEFAULT_FFT_PRECISION, - concrete_optimizer::ParameterRestrictions{}, + DEFAULT_RANGE_RESTRICTION, + DEFAULT_KEYSET_RESTRICTION, DEFAULT_COMPOSITION_RULES, DEFAULT_COMPOSABLE}; diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 112bf2b91e..c1183a8128 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -269,6 +270,64 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .value("NATIVE", concrete_optimizer::Encoding::Native) .export_values(); + // ------------------------------------------------------------------------------// + // RANGE RESTRICTION // + // ------------------------------------------------------------------------------// + pybind11::class_( + m, "RangeRestriction") + .def(pybind11::init( + []() { return concrete_optimizer::restriction::RangeRestriction(); })) + .def( + "add_available_glwe_log_polynomial_size", + [](concrete_optimizer::restriction::RangeRestriction &restriction, + uint64_t v) { + restriction.glwe_log_polynomial_sizes.push_back(v); + }, + "Add an available glwe log poly size to the restriction") + .def( + "add_available_glwe_dimension", + [](concrete_optimizer::restriction::RangeRestriction &restriction, + uint64_t v) { restriction.glwe_dimensions.push_back(v); }, + "Add an available glwe dimension to the restriction") + .def( + "add_available_internal_lwe_dimension", + [](concrete_optimizer::restriction::RangeRestriction &restriction, + uint64_t v) { restriction.internal_lwe_dimensions.push_back(v); }, + "Add an available internal lwe dimension to the restriction") + .def( + "add_available_pbs_level_count", + [](concrete_optimizer::restriction::RangeRestriction &restriction, + uint64_t v) { restriction.pbs_level_count.push_back(v); }, + "Add an available pbs level count to the restriction") + .def( + "add_available_pbs_base_log", + [](concrete_optimizer::restriction::RangeRestriction &restriction, + uint64_t v) { restriction.pbs_base_log.push_back(v); }, + "Add an available pbs base log to the restriction") + .def( + "add_available_ks_level_count", + [](concrete_optimizer::restriction::RangeRestriction &restriction, + uint64_t v) { restriction.ks_level_count.push_back(v); }, + "Add an available ks level count to the restriction") + .def( + "add_available_ks_base_log", + [](concrete_optimizer::restriction::RangeRestriction &restriction, + uint64_t v) { restriction.ks_base_log.push_back(v); }, + "Add an available ks base log to the restriction") + .doc() = "Allow to restrict the optimizer parameter search space to a " + "set of values."; + + // ------------------------------------------------------------------------------// + // KEYSET RESTRICTION // + // ------------------------------------------------------------------------------// + pybind11::class_( + m, "KeysetRestriction") + .doc() = "Allow to restrict the optimizer search space to be compatible " + "with a keyset."; + + // ------------------------------------------------------------------------------// + // COMPILATION OPTIONS // + // ------------------------------------------------------------------------------// pybind11::class_(m, "CompilationOptions") .def(pybind11::init([](mlir::concretelang::Backend backend) { return CompilationOptions(backend); @@ -373,117 +432,19 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( options.optimizerConfig.security = security_level; }, "Set security level.", arg("security_level")) - .def("set_glwe_pbs_restrictions", + .def("set_range_restriction", [](CompilationOptions &options, - std::optional log2_polynomial_size_min, - std::optional log2_polynomial_size_max, - std::optional glwe_dimension_min, - std::optional glwe_dimension_max) { - if (log2_polynomial_size_min) { - options.optimizerConfig.parameter_restrictions.glwe_pbs - .log2_polynomial_size_min = - std::make_shared(*log2_polynomial_size_min); - } - if (log2_polynomial_size_max) { - options.optimizerConfig.parameter_restrictions.glwe_pbs - .log2_polynomial_size_max = - std::make_shared(*log2_polynomial_size_max); - } - if (glwe_dimension_min) { - options.optimizerConfig.parameter_restrictions.glwe_pbs - .glwe_dimension_min = - std::make_shared(*glwe_dimension_min); - } - if (glwe_dimension_max) { - options.optimizerConfig.parameter_restrictions.glwe_pbs - .glwe_dimension_max = - std::make_shared(*glwe_dimension_max); - } + concrete_optimizer::restriction::RangeRestriction restriction) { + options.optimizerConfig.range_restriction = std::make_shared< + concrete_optimizer::restriction::RangeRestriction>( + restriction); }) - .def("set_free_glwe_restrictions", + .def("set_keyset_restriction", [](CompilationOptions &options, - std::optional log2_polynomial_size_min, - std::optional log2_polynomial_size_max, - std::optional glwe_dimension_min, - std::optional glwe_dimension_max) { - if (log2_polynomial_size_min) { - options.optimizerConfig.parameter_restrictions.free_glwe - .log2_polynomial_size_min = - std::make_shared(*log2_polynomial_size_min); - } - if (log2_polynomial_size_max) { - options.optimizerConfig.parameter_restrictions.free_glwe - .log2_polynomial_size_max = - std::make_shared(*log2_polynomial_size_max); - } - if (glwe_dimension_min) { - options.optimizerConfig.parameter_restrictions.free_glwe - .glwe_dimension_min = - std::make_shared(*glwe_dimension_min); - } - if (glwe_dimension_max) { - options.optimizerConfig.parameter_restrictions.free_glwe - .glwe_dimension_max = - std::make_shared(*glwe_dimension_max); - } - }) - .def("set_br_decomposition_restrictions", - [](CompilationOptions &options, - std::optional log2_base_min, - std::optional log2_base_max, - std::optional level_min, - std::optional level_max) { - if (log2_base_min) { - options.optimizerConfig.parameter_restrictions.br_decomposition - .log2_base_min = std::make_shared(*log2_base_min); - } - if (log2_base_max) { - options.optimizerConfig.parameter_restrictions.br_decomposition - .log2_base_max = std::make_shared(*log2_base_max); - } - if (level_min) { - options.optimizerConfig.parameter_restrictions.br_decomposition - .level_min = std::make_shared(*level_min); - } - if (level_max) { - options.optimizerConfig.parameter_restrictions.br_decomposition - .level_max = std::make_shared(*level_max); - } - }) - .def("set_ks_decomposition_restrictions", - [](CompilationOptions &options, - std::optional log2_base_min, - std::optional log2_base_max, - std::optional level_min, - std::optional level_max) { - if (log2_base_min) { - options.optimizerConfig.parameter_restrictions.ks_decomposition - .log2_base_min = std::make_shared(*log2_base_min); - } - if (log2_base_max) { - options.optimizerConfig.parameter_restrictions.ks_decomposition - .log2_base_max = std::make_shared(*log2_base_max); - } - if (level_min) { - options.optimizerConfig.parameter_restrictions.ks_decomposition - .level_min = std::make_shared(*level_min); - } - if (level_max) { - options.optimizerConfig.parameter_restrictions.ks_decomposition - .level_max = std::make_shared(*level_max); - } - }) - .def("set_free_lwe_restrictions", - [](CompilationOptions &options, std::optional free_lwe_min, - std::optional free_lwe_max) { - if (free_lwe_min) { - options.optimizerConfig.parameter_restrictions.free_lwe_min = - std::make_shared(*free_lwe_min); - } - if (free_lwe_max) { - options.optimizerConfig.parameter_restrictions.free_lwe_max = - std::make_shared(*free_lwe_max); - } + concrete_optimizer::restriction::KeysetRestriction restriction) { + options.optimizerConfig.keyset_restriction = std::make_shared< + concrete_optimizer::restriction::KeysetRestriction>( + restriction); }) .def( "set_v0_parameter", @@ -933,6 +894,85 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( }) .doc() = "Parameters of a packing keyswitch key."; + // ------------------------------------------------------------------------------// + // KEYSET INFO // + // ------------------------------------------------------------------------------// + typedef Message KeysetInfo; + pybind11::class_(m, "KeysetInfo") + .def( + "secret_keys", + [](KeysetInfo &keysetInfo) { + auto secretKeys = std::vector(); + for (auto key : keysetInfo.asReader().getLweSecretKeys()) { + secretKeys.push_back(LweSecretKeyParam{key}); + } + return secretKeys; + }, + "Return the parameters of the secret keys for this keyset.") + .def( + "bootstrap_keys", + [](KeysetInfo &keysetInfo) { + auto bootstrapKeys = std::vector(); + for (auto key : keysetInfo.asReader().getLweBootstrapKeys()) { + bootstrapKeys.push_back(BootstrapKeyParam{key}); + } + return bootstrapKeys; + }, + "Return the parameters of the bootstrap keys for this keyset.") + .def( + "keyswitch_keys", + [](KeysetInfo &keysetInfo) { + auto keyswitchKeys = std::vector(); + for (auto key : keysetInfo.asReader().getLweKeyswitchKeys()) { + keyswitchKeys.push_back(KeyswitchKeyParam{key}); + } + return keyswitchKeys; + }, + "Return the parameters of the keyswitch keys for this keyset.") + .def( + "packing_keyswitch_keys", + [](KeysetInfo &keysetInfo) { + auto packingKeyswitchKeys = std::vector(); + for (auto key : keysetInfo.asReader().getPackingKeyswitchKeys()) { + packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{key}); + } + return packingKeyswitchKeys; + }, + "Return the parameters of the packing keyswitch keys for this " + "keyset.") + .def( + "get_restriction", + [](KeysetInfo &keysetInfo) { + concrete_optimizer::restriction::KeysetInfo output; + for (auto key : keysetInfo.asReader().getLweSecretKeys()) { + output.lwe_secret_keys.push_back( + concrete_optimizer::restriction::LweSecretKeyInfo{ + key.getParams().getLweDimension()}); + } + for (auto key : keysetInfo.asReader().getLweBootstrapKeys()) { + output.lwe_bootstrap_keys.push_back( + concrete_optimizer::restriction::LweBootstrapKeyInfo{ + key.getParams().getLevelCount(), + key.getParams().getBaseLog(), + key.getParams().getGlweDimension(), + key.getParams().getPolynomialSize(), + key.getParams().getInputLweDimension()}); + } + for (auto key : keysetInfo.asReader().getLweKeyswitchKeys()) { + output.lwe_keyswitch_keys.push_back( + concrete_optimizer::restriction::LweKeyswitchKeyInfo{ + key.getParams().getLevelCount(), + key.getParams().getBaseLog(), + key.getParams().getInputLweDimension(), + key.getParams().getOutputLweDimension()}); + } + return concrete_optimizer::restriction::KeysetRestriction{output}; + }, + "Return the search space restriction associated to this keyset info.") + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) + .doc() = "Parameters of a complete keyset."; + // ------------------------------------------------------------------------------// // TYPE INFO // // ------------------------------------------------------------------------------// @@ -1168,55 +1208,13 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return result; }, "Return the signedness of the input of the first circuit.") + .def( - "secret_keys", - [](ProgramInfo &programInfo) { - auto secretKeys = std::vector(); - for (auto key : programInfo.programInfo.asReader() - .getKeyset() - .getLweSecretKeys()) { - secretKeys.push_back(LweSecretKeyParam{key}); - } - return secretKeys; - }, - "Return the parameters of the secret keys for this program.") - .def( - "bootstrap_keys", - [](ProgramInfo &programInfo) { - auto bootstrapKeys = std::vector(); - for (auto key : programInfo.programInfo.asReader() - .getKeyset() - .getLweBootstrapKeys()) { - bootstrapKeys.push_back(BootstrapKeyParam{key}); - } - return bootstrapKeys; - }, - "Return the parameters of the bootstrap keys for this program.") - .def( - "keyswitch_keys", - [](ProgramInfo &programInfo) { - auto keyswitchKeys = std::vector(); - for (auto key : programInfo.programInfo.asReader() - .getKeyset() - .getLweKeyswitchKeys()) { - keyswitchKeys.push_back(KeyswitchKeyParam{key}); - } - return keyswitchKeys; - }, - "Return the parameters of the keyswitch keys for this program.") - .def( - "packing_keyswitch_keys", - [](ProgramInfo &programInfo) { - auto packingKeyswitchKeys = std::vector(); - for (auto key : programInfo.programInfo.asReader() - .getKeyset() - .getPackingKeyswitchKeys()) { - packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{key}); - } - return packingKeyswitchKeys; + "get_keyset_info", + [](ProgramInfo &programInfo) -> KeysetInfo { + return programInfo.programInfo.asReader().getKeyset(); }, - "Return the parameters of the packing keyswitch keys for this " - "program.") + "Return the keyset info associated to the program.") .def( "get_circuits", [](ProgramInfo &programInfo) { @@ -1297,7 +1295,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( std::make_shared>(std::move(glwe_sk)), params.info); }, - "Deserialize an LweSecretKey from glwe encoded (tfhe-rs compatible) " + "Deserialize an LweSecretKey from glwe encoded (tfhe-rs " + "compatible) " "bytes and associated parameters.", arg("buffer"), arg("params")) .def( @@ -1852,7 +1851,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( GET_OR_THROW_RESULT(auto ok, circuit.processOutput(result, pos)); return ok; }, - "Process a `pos` positional result `result` retrieved from server. ", + "Process a `pos` positional result `result` retrieved from " + "server. ", arg("result"), arg("pos")) .def( "simulate_prepare_input", @@ -1866,7 +1866,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( typeTransformer(arg), pos)); return ok; }, - "SIMULATE preparation of `pos` positional argument `arg` to be sent " + "SIMULATE preparation of `pos` positional argument `arg` to be " + "sent " "to server. DOES NOT NCRYPT.", arg("arg"), arg("pos")) .def( diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py index 294c36e118..acd71f6375 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -100,13 +100,19 @@ def count_per_parameter( continue if key_type == KeyType.SECRET: - parameter = program_info.secret_keys()[key_index] + parameter = program_info.get_keyset_info().secret_keys()[key_index] elif key_type == KeyType.BOOTSTRAP: - parameter = program_info.bootstrap_keys()[key_index] + parameter = program_info.get_keyset_info().bootstrap_keys()[ + key_index + ] elif key_type == KeyType.KEY_SWITCH: - parameter = program_info.keyswitch_keys()[key_index] + parameter = program_info.get_keyset_info().keyswitch_keys()[ + key_index + ] elif key_type == KeyType.PACKING_KEY_SWITCH: - parameter = program_info.packing_keyswitch_keys()[key_index] + parameter = program_info.get_keyset_info().packing_keyswitch_keys()[ + key_index + ] else: assert False if parameter not in result: diff --git a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp index 6476440763..ccf5ed30b9 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp @@ -36,7 +36,17 @@ concrete_optimizer::Options options_from_config(optimizer::Config config) { /* .cache_on_disk = */ config.cache_on_disk, /* .ciphertext_modulus_log = */ config.ciphertext_modulus_log, /* .fft_precision = */ config.fft_precision, - /* .parameter_restrictions = */ config.parameter_restrictions}; + /* .range_restriction = */ + std::shared_ptr(), + /* .keyset_restriction = */ + std::shared_ptr(), + }; + if (config.range_restriction) { + options.range_restriction = config.range_restriction; + } + if (config.keyset_restriction) { + options.keyset_restriction = config.keyset_restriction; + } return options; } diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index 06f38c9eb9..b195d12424 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -10,12 +10,14 @@ use concrete_optimizer::dag::operator::{ self, FunctionTable, LevelledComplexity, OperatorIndex, Precision, Shape, }; use concrete_optimizer::dag::unparametrized; -use concrete_optimizer::global_parameters::{ParameterDomains, DEFAULT_DOMAINS}; use concrete_optimizer::optimization::config::{Config, SearchSpace}; -use concrete_optimizer::optimization::dag::multi_parameters::keys_spec; use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::CircuitSolution; -use concrete_optimizer::optimization::dag::multi_parameters::optimize::MacroParameters; +use concrete_optimizer::optimization::dag::multi_parameters::optimize::{ + KeysetRestriction, MacroParameters, NoSearchSpaceRestriction, RangeRestriction, + SearchSpaceRestriction, +}; use concrete_optimizer::optimization::dag::multi_parameters::partition_cut::PartitionCut; +use concrete_optimizer::optimization::dag::multi_parameters::{keys_spec, PartitionIndex}; use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{ Encoding, Solution as DagSolution, }; @@ -59,195 +61,6 @@ fn caches_from(options: &ffi::Options) -> decomposition::PersistDecompCaches { ) } -fn calculate_parameter_domain(options: &ffi::Options) -> ParameterDomains { - let mut domains = DEFAULT_DOMAINS; - - if !options - .parameter_restrictions - .glwe_pbs - .log2_polynomial_size_min - .is_null() - { - domains.glwe_pbs_constrained_cpu.log2_polynomial_size.start = *options - .parameter_restrictions - .glwe_pbs - .log2_polynomial_size_min; - domains.glwe_pbs_constrained_gpu.log2_polynomial_size.start = *options - .parameter_restrictions - .glwe_pbs - .log2_polynomial_size_min; - } - if !options - .parameter_restrictions - .glwe_pbs - .log2_polynomial_size_max - .is_null() - { - domains.glwe_pbs_constrained_cpu.log2_polynomial_size.end = *options - .parameter_restrictions - .glwe_pbs - .log2_polynomial_size_max; - domains.glwe_pbs_constrained_gpu.log2_polynomial_size.end = *options - .parameter_restrictions - .glwe_pbs - .log2_polynomial_size_max; - } - if !options - .parameter_restrictions - .glwe_pbs - .glwe_dimension_min - .is_null() - { - domains.glwe_pbs_constrained_cpu.glwe_dimension.start = - *options.parameter_restrictions.glwe_pbs.glwe_dimension_min; - domains.glwe_pbs_constrained_gpu.glwe_dimension.start = - *options.parameter_restrictions.glwe_pbs.glwe_dimension_min; - } - if !options - .parameter_restrictions - .glwe_pbs - .glwe_dimension_max - .is_null() - { - domains.glwe_pbs_constrained_cpu.glwe_dimension.end = - *options.parameter_restrictions.glwe_pbs.glwe_dimension_max; - domains.glwe_pbs_constrained_gpu.glwe_dimension.end = - *options.parameter_restrictions.glwe_pbs.glwe_dimension_max; - } - - if !options - .parameter_restrictions - .free_glwe - .log2_polynomial_size_min - .is_null() - { - domains.free_glwe.log2_polynomial_size.start = *options - .parameter_restrictions - .free_glwe - .log2_polynomial_size_min; - } - if !options - .parameter_restrictions - .free_glwe - .log2_polynomial_size_max - .is_null() - { - domains.free_glwe.log2_polynomial_size.end = *options - .parameter_restrictions - .free_glwe - .log2_polynomial_size_max; - } - if !options - .parameter_restrictions - .free_glwe - .glwe_dimension_min - .is_null() - { - domains.free_glwe.glwe_dimension.start = - *options.parameter_restrictions.free_glwe.glwe_dimension_min; - } - if !options - .parameter_restrictions - .free_glwe - .glwe_dimension_max - .is_null() - { - domains.free_glwe.glwe_dimension.end = - *options.parameter_restrictions.free_glwe.glwe_dimension_max; - } - - if !options - .parameter_restrictions - .br_decomposition - .log2_base_min - .is_null() - { - domains.br_decomposition.log2_base.start = *options - .parameter_restrictions - .br_decomposition - .log2_base_min; - } - if !options - .parameter_restrictions - .br_decomposition - .log2_base_max - .is_null() - { - domains.br_decomposition.log2_base.end = *options - .parameter_restrictions - .br_decomposition - .log2_base_max; - } - if !options - .parameter_restrictions - .br_decomposition - .level_min - .is_null() - { - domains.br_decomposition.level.start = - *options.parameter_restrictions.br_decomposition.level_min; - } - if !options - .parameter_restrictions - .br_decomposition - .level_max - .is_null() - { - domains.br_decomposition.level.end = - *options.parameter_restrictions.br_decomposition.level_max; - } - - if !options - .parameter_restrictions - .ks_decomposition - .log2_base_min - .is_null() - { - domains.ks_decomposition.log2_base.start = *options - .parameter_restrictions - .ks_decomposition - .log2_base_min; - } - if !options - .parameter_restrictions - .ks_decomposition - .log2_base_max - .is_null() - { - domains.ks_decomposition.log2_base.end = *options - .parameter_restrictions - .ks_decomposition - .log2_base_max; - } - if !options - .parameter_restrictions - .ks_decomposition - .level_min - .is_null() - { - domains.ks_decomposition.level.start = - *options.parameter_restrictions.ks_decomposition.level_min; - } - if !options - .parameter_restrictions - .ks_decomposition - .level_max - .is_null() - { - domains.ks_decomposition.level.end = - *options.parameter_restrictions.ks_decomposition.level_max; - } - - if !options.parameter_restrictions.free_lwe_min.is_null() { - domains.free_lwe.start = *options.parameter_restrictions.free_lwe_min; - } - if !options.parameter_restrictions.free_lwe_max.is_null() { - domains.free_lwe.end = *options.parameter_restrictions.free_lwe_max; - } - - domains -} - #[derive(Clone)] pub struct ExternalPartition( concrete_optimizer::optimization::dag::multi_parameters::partition_cut::ExternalPartition, @@ -317,8 +130,7 @@ fn optimize_bootstrap(precision: u64, noise_factor: f64, options: &ffi::Options) let sum_size = 1; - let parameter_restrictions = calculate_parameter_domain(options); - let search_space = SearchSpace::default(processing_unit, parameter_restrictions); + let search_space = SearchSpace::default(processing_unit); let result = concrete_optimizer::optimization::atomic_pattern::optimize_one( sum_size, @@ -768,8 +580,7 @@ impl Dag { complexity_model: &CpuComplexity::default(), }; - let parameter_restrictions = calculate_parameter_domain(options); - let search_space = SearchSpace::default(processing_unit, parameter_restrictions); + let search_space = SearchSpace::default(processing_unit); let encoding = options.encoding.into(); @@ -840,8 +651,7 @@ impl Dag { fft_precision: options.fft_precision, complexity_model: &CpuComplexity::default(), }; - let parameter_restrictions = calculate_parameter_domain(options); - let search_space = SearchSpace::default(processing_unit, parameter_restrictions); + let search_space = SearchSpace::default(processing_unit); let encoding = options.encoding.into(); #[allow(clippy::wildcard_in_or_patterns)] @@ -852,15 +662,54 @@ impl Dag { ffi::MultiParamStrategy::ByPrecision | _ => PartitionCut::for_each_precision(&self.0), }; let circuit_sol = - concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize( - &self.0, - config, - &search_space, - encoding, - options.default_log_norm2_woppbs, - &caches_from(options), - &Some(p_cut), - ); + if !options.keyset_restriction.is_null() && !options.range_restriction.is_null() { + concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize( + &self.0, + config, + &search_space, + &( + (*options.keyset_restriction).clone(), + (*options.range_restriction).clone(), + ), + encoding, + options.default_log_norm2_woppbs, + &caches_from(options), + &Some(p_cut), + ) + } else if !options.keyset_restriction.is_null() { + concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize( + &self.0, + config, + &search_space, + &*options.keyset_restriction, + encoding, + options.default_log_norm2_woppbs, + &caches_from(options), + &Some(p_cut), + ) + } else if !options.range_restriction.is_null() { + concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize( + &self.0, + config, + &search_space, + &*options.range_restriction, + encoding, + options.default_log_norm2_woppbs, + &caches_from(options), + &Some(p_cut), + ) + } else { + concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize( + &self.0, + config, + &search_space, + &NoSearchSpaceRestriction, + encoding, + options.default_log_norm2_woppbs, + &caches_from(options), + &Some(p_cut), + ) + }; circuit_sol.into() } } @@ -1335,33 +1184,16 @@ mod ffi { ByPrecisionAndNorm2, } - #[namespace = "concrete_optimizer"] + #[namespace = "concrete_optimizer::restriction"] #[derive(Debug, Clone)] - pub struct GlweParameterRestrictions { - pub log2_polynomial_size_min: SharedPtr, - pub log2_polynomial_size_max: SharedPtr, - pub glwe_dimension_min: SharedPtr, - pub glwe_dimension_max: SharedPtr, - } - - #[namespace = "concrete_optimizer"] - #[derive(Debug, Clone)] - pub struct DecompositionParameterRestrictions { - pub log2_base_min: SharedPtr, - pub log2_base_max: SharedPtr, - pub level_min: SharedPtr, - pub level_max: SharedPtr, - } - - #[namespace = "concrete_optimizer"] - #[derive(Debug, Clone)] - pub struct ParameterRestrictions { - pub glwe_pbs: GlweParameterRestrictions, - pub free_glwe: GlweParameterRestrictions, - pub br_decomposition: DecompositionParameterRestrictions, - pub ks_decomposition: DecompositionParameterRestrictions, - pub free_lwe_min: SharedPtr, - pub free_lwe_max: SharedPtr, + pub struct RangeRestriction { + pub glwe_log_polynomial_sizes: Vec, + pub glwe_dimensions: Vec, + pub internal_lwe_dimensions: Vec, + pub pbs_level_count: Vec, + pub pbs_base_log: Vec, + pub ks_level_count: Vec, + pub ks_base_log: Vec, } #[namespace = "concrete_optimizer"] @@ -1377,7 +1209,8 @@ mod ffi { pub cache_on_disk: bool, pub ciphertext_modulus_log: u32, pub fft_precision: u32, - pub parameter_restrictions: ParameterRestrictions, + pub range_restriction: SharedPtr, // SharedPtr used for Options since optionals are not available... + pub keyset_restriction: SharedPtr, // SharedPtr used for Options since optionals are not available... } #[namespace = "concrete_optimizer::dag"] @@ -1487,6 +1320,45 @@ mod ffi { pub is_feasible: bool, pub error_msg: String, } + + #[namespace = "concrete_optimizer::restriction"] + #[derive(Debug, Clone)] + pub struct LweSecretKeyInfo { + pub lwe_dimension: u64, + } + + #[namespace = "concrete_optimizer::restriction"] + #[derive(Debug, Clone)] + pub struct LweBootstrapKeyInfo { + pub level_count: u64, + pub base_log: u64, + pub glwe_dimension: u64, + pub polynomial_size: u64, + pub input_lwe_dimension: u64, + } + + #[namespace = "concrete_optimizer::restriction"] + #[derive(Debug, Clone)] + pub struct LweKeyswitchKeyInfo { + pub level_count: u64, + pub base_log: u64, + pub input_lwe_dimension: u64, + pub output_lwe_dimension: u64, + } + + #[namespace = "concrete_optimizer::restriction"] + #[derive(Debug, Clone)] + pub struct KeysetInfo { + pub lwe_secret_keys: Vec, + pub lwe_bootstrap_keys: Vec, + pub lwe_keyswitch_keys: Vec, + } + + #[namespace = "concrete_optimizer::restriction"] + #[derive(Debug, Clone)] + pub struct KeysetRestriction { + pub info: KeysetInfo, + } } fn processing_unit(options: &ffi::Options) -> ProcessingUnit { @@ -1499,3 +1371,149 @@ fn processing_unit(options: &ffi::Options) -> ProcessingUnit { config::ProcessingUnit::Cpu } } + +impl SearchSpaceRestriction for ffi::RangeRestriction { + fn is_available_glwe(&self, partition: PartitionIndex, glwe_params: GlweParameters) -> bool { + unsafe { + std::mem::transmute::<&Self, &RangeRestriction>(self) + .is_available_glwe(partition, glwe_params) + } + } + + fn is_available_macro( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + ) -> bool { + unsafe { + std::mem::transmute::<&Self, &RangeRestriction>(self) + .is_available_macro(partition, macro_parameters) + } + } + + fn is_available_micro_pbs( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + pbs_parameters: BrDecompositionParameters, + ) -> bool { + unsafe { + std::mem::transmute::<&Self, &RangeRestriction>(self).is_available_micro_pbs( + partition, + macro_parameters, + pbs_parameters, + ) + } + } + + fn is_available_micro_ks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + unsafe { + std::mem::transmute::<&Self, &RangeRestriction>(self).is_available_micro_ks( + from_partition, + from_macro, + to_partition, + to_macro, + ks_parameters, + ) + } + } + + fn is_available_micro_fks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + unsafe { + std::mem::transmute::<&Self, &RangeRestriction>(self).is_available_micro_fks( + from_partition, + from_macro, + to_partition, + to_macro, + ks_parameters, + ) + } + } +} + +impl SearchSpaceRestriction for ffi::KeysetRestriction { + fn is_available_glwe(&self, partition: PartitionIndex, glwe_params: GlweParameters) -> bool { + unsafe { + std::mem::transmute::<&Self, &KeysetRestriction>(self) + .is_available_glwe(partition, glwe_params) + } + } + + fn is_available_macro( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + ) -> bool { + unsafe { + std::mem::transmute::<&Self, &KeysetRestriction>(self) + .is_available_macro(partition, macro_parameters) + } + } + + fn is_available_micro_pbs( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + pbs_parameters: BrDecompositionParameters, + ) -> bool { + unsafe { + std::mem::transmute::<&Self, &KeysetRestriction>(self).is_available_micro_pbs( + partition, + macro_parameters, + pbs_parameters, + ) + } + } + + fn is_available_micro_ks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + unsafe { + std::mem::transmute::<&Self, &KeysetRestriction>(self).is_available_micro_ks( + from_partition, + from_macro, + to_partition, + to_macro, + ks_parameters, + ) + } + } + + fn is_available_micro_fks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + unsafe { + std::mem::transmute::<&Self, &KeysetRestriction>(self).is_available_micro_fks( + from_partition, + from_macro, + to_partition, + to_macro, + ks_parameters, + ) + } + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 7851f883f0..6c6c8075cb 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -949,9 +949,6 @@ namespace concrete_optimizer { struct Weights; enum class Encoding : ::std::uint8_t; enum class MultiParamStrategy : ::std::uint8_t; - struct GlweParameterRestrictions; - struct DecompositionParameterRestrictions; - struct ParameterRestrictions; struct Options; namespace dag { struct OperatorIndex; @@ -969,6 +966,14 @@ namespace concrete_optimizer { namespace v0 { struct Solution; } + namespace restriction { + struct RangeRestriction; + struct LweSecretKeyInfo; + struct LweBootstrapKeyInfo; + struct LweKeyswitchKeyInfo; + struct KeysetInfo; + struct KeysetRestriction; + } } namespace concrete_optimizer { @@ -1140,43 +1145,22 @@ enum class MultiParamStrategy : ::std::uint8_t { }; #endif // CXXBRIDGE1_ENUM_concrete_optimizer$MultiParamStrategy -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$GlweParameterRestrictions -#define CXXBRIDGE1_STRUCT_concrete_optimizer$GlweParameterRestrictions -struct GlweParameterRestrictions final { - ::std::shared_ptr<::std::uint64_t> log2_polynomial_size_min; - ::std::shared_ptr<::std::uint64_t> log2_polynomial_size_max; - ::std::shared_ptr<::std::uint64_t> glwe_dimension_min; - ::std::shared_ptr<::std::uint64_t> glwe_dimension_max; - - using IsRelocatable = ::std::true_type; -}; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$GlweParameterRestrictions - -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$DecompositionParameterRestrictions -#define CXXBRIDGE1_STRUCT_concrete_optimizer$DecompositionParameterRestrictions -struct DecompositionParameterRestrictions final { - ::std::shared_ptr<::std::uint64_t> log2_base_min; - ::std::shared_ptr<::std::uint64_t> log2_base_max; - ::std::shared_ptr<::std::uint64_t> level_min; - ::std::shared_ptr<::std::uint64_t> level_max; - - using IsRelocatable = ::std::true_type; -}; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$DecompositionParameterRestrictions - -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$ParameterRestrictions -#define CXXBRIDGE1_STRUCT_concrete_optimizer$ParameterRestrictions -struct ParameterRestrictions final { - ::concrete_optimizer::GlweParameterRestrictions glwe_pbs; - ::concrete_optimizer::GlweParameterRestrictions free_glwe; - ::concrete_optimizer::DecompositionParameterRestrictions br_decomposition; - ::concrete_optimizer::DecompositionParameterRestrictions ks_decomposition; - ::std::shared_ptr<::std::uint64_t> free_lwe_min; - ::std::shared_ptr<::std::uint64_t> free_lwe_max; +namespace restriction { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$RangeRestriction +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$RangeRestriction +struct RangeRestriction final { + ::rust::Vec<::std::uint64_t> glwe_log_polynomial_sizes; + ::rust::Vec<::std::uint64_t> glwe_dimensions; + ::rust::Vec<::std::uint64_t> internal_lwe_dimensions; + ::rust::Vec<::std::uint64_t> pbs_level_count; + ::rust::Vec<::std::uint64_t> pbs_base_log; + ::rust::Vec<::std::uint64_t> ks_level_count; + ::rust::Vec<::std::uint64_t> ks_base_log; using IsRelocatable = ::std::true_type; }; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$ParameterRestrictions +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$RangeRestriction +} // namespace restriction #ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Options #define CXXBRIDGE1_STRUCT_concrete_optimizer$Options @@ -1191,7 +1175,8 @@ struct Options final { bool cache_on_disk; ::std::uint32_t ciphertext_modulus_log; ::std::uint32_t fft_precision; - ::concrete_optimizer::ParameterRestrictions parameter_restrictions; + ::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction> range_restriction; + ::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction> keyset_restriction; using IsRelocatable = ::std::true_type; }; @@ -1346,6 +1331,62 @@ struct CircuitSolution final { #endif // CXXBRIDGE1_STRUCT_concrete_optimizer$dag$CircuitSolution } // namespace dag +namespace restriction { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweSecretKeyInfo +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweSecretKeyInfo +struct LweSecretKeyInfo final { + ::std::uint64_t lwe_dimension; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweSecretKeyInfo + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweBootstrapKeyInfo +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweBootstrapKeyInfo +struct LweBootstrapKeyInfo final { + ::std::uint64_t level_count; + ::std::uint64_t base_log; + ::std::uint64_t glwe_dimension; + ::std::uint64_t polynomial_size; + ::std::uint64_t input_lwe_dimension; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweBootstrapKeyInfo + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweKeyswitchKeyInfo +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweKeyswitchKeyInfo +struct LweKeyswitchKeyInfo final { + ::std::uint64_t level_count; + ::std::uint64_t base_log; + ::std::uint64_t input_lwe_dimension; + ::std::uint64_t output_lwe_dimension; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweKeyswitchKeyInfo + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetInfo +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetInfo +struct KeysetInfo final { + ::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> lwe_secret_keys; + ::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> lwe_bootstrap_keys; + ::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> lwe_keyswitch_keys; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetInfo + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction +struct KeysetRestriction final { + ::concrete_optimizer::restriction::KeysetInfo info; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction +} // namespace restriction + namespace v0 { extern "C" { ::concrete_optimizer::v0::Solution concrete_optimizer$v0$cxxbridge1$optimize_bootstrap(::std::uint64_t precision, double noise_factor, ::concrete_optimizer::Options const &options) noexcept; @@ -1697,6 +1738,46 @@ void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$reserve_total(::ru void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$set_len(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr, ::std::size_t len) noexcept; void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$truncate(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr, ::std::size_t len) noexcept; +static_assert(sizeof(::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction>) == 2 * sizeof(void *), ""); +static_assert(alignof(::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction>) == alignof(void *), ""); +void cxxbridge1$shared_ptr$concrete_optimizer$restriction$RangeRestriction$null(::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction> *ptr) noexcept { + ::new (ptr) ::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction>(); +} +::concrete_optimizer::restriction::RangeRestriction *cxxbridge1$shared_ptr$concrete_optimizer$restriction$RangeRestriction$uninit(::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction> *ptr) noexcept { + ::concrete_optimizer::restriction::RangeRestriction *uninit = reinterpret_cast<::concrete_optimizer::restriction::RangeRestriction *>(new ::rust::MaybeUninit<::concrete_optimizer::restriction::RangeRestriction>); + ::new (ptr) ::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction>(uninit); + return uninit; +} +void cxxbridge1$shared_ptr$concrete_optimizer$restriction$RangeRestriction$clone(::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction> const &self, ::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction> *ptr) noexcept { + ::new (ptr) ::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction>(self); +} +::concrete_optimizer::restriction::RangeRestriction const *cxxbridge1$shared_ptr$concrete_optimizer$restriction$RangeRestriction$get(::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction> const &self) noexcept { + return self.get(); +} +void cxxbridge1$shared_ptr$concrete_optimizer$restriction$RangeRestriction$drop(::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction> *self) noexcept { + self->~shared_ptr(); +} + +static_assert(sizeof(::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction>) == 2 * sizeof(void *), ""); +static_assert(alignof(::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction>) == alignof(void *), ""); +void cxxbridge1$shared_ptr$concrete_optimizer$restriction$KeysetRestriction$null(::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction> *ptr) noexcept { + ::new (ptr) ::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction>(); +} +::concrete_optimizer::restriction::KeysetRestriction *cxxbridge1$shared_ptr$concrete_optimizer$restriction$KeysetRestriction$uninit(::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction> *ptr) noexcept { + ::concrete_optimizer::restriction::KeysetRestriction *uninit = reinterpret_cast<::concrete_optimizer::restriction::KeysetRestriction *>(new ::rust::MaybeUninit<::concrete_optimizer::restriction::KeysetRestriction>); + ::new (ptr) ::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction>(uninit); + return uninit; +} +void cxxbridge1$shared_ptr$concrete_optimizer$restriction$KeysetRestriction$clone(::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction> const &self, ::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction> *ptr) noexcept { + ::new (ptr) ::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction>(self); +} +::concrete_optimizer::restriction::KeysetRestriction const *cxxbridge1$shared_ptr$concrete_optimizer$restriction$KeysetRestriction$get(::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction> const &self) noexcept { + return self.get(); +} +void cxxbridge1$shared_ptr$concrete_optimizer$restriction$KeysetRestriction$drop(::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction> *self) noexcept { + self->~shared_ptr(); +} + void cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$new(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> const *ptr) noexcept; void cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$drop(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> *ptr) noexcept; ::std::size_t cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$len(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> const *ptr) noexcept; @@ -1759,6 +1840,33 @@ ::concrete_optimizer::dag::InstructionKeys const *cxxbridge1$rust_vec$concrete_o void cxxbridge1$rust_vec$concrete_optimizer$dag$InstructionKeys$reserve_total(::rust::Vec<::concrete_optimizer::dag::InstructionKeys> *ptr, ::std::size_t new_cap) noexcept; void cxxbridge1$rust_vec$concrete_optimizer$dag$InstructionKeys$set_len(::rust::Vec<::concrete_optimizer::dag::InstructionKeys> *ptr, ::std::size_t len) noexcept; void cxxbridge1$rust_vec$concrete_optimizer$dag$InstructionKeys$truncate(::rust::Vec<::concrete_optimizer::dag::InstructionKeys> *ptr, ::std::size_t len) noexcept; + +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$new(::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$drop(::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$len(::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> const *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$capacity(::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> const *ptr) noexcept; +::concrete_optimizer::restriction::LweSecretKeyInfo const *cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$data(::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$reserve_total(::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> *ptr, ::std::size_t new_cap) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$set_len(::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> *ptr, ::std::size_t len) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$truncate(::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> *ptr, ::std::size_t len) noexcept; + +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$new(::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$drop(::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$len(::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> const *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$capacity(::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> const *ptr) noexcept; +::concrete_optimizer::restriction::LweBootstrapKeyInfo const *cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$data(::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$reserve_total(::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> *ptr, ::std::size_t new_cap) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$set_len(::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> *ptr, ::std::size_t len) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$truncate(::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> *ptr, ::std::size_t len) noexcept; + +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$new(::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$drop(::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$len(::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> const *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$capacity(::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> const *ptr) noexcept; +::concrete_optimizer::restriction::LweKeyswitchKeyInfo const *cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$data(::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$reserve_total(::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> *ptr, ::std::size_t new_cap) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$set_len(::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> *ptr, ::std::size_t len) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$truncate(::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> *ptr, ::std::size_t len) noexcept; } // extern "C" namespace rust { @@ -2079,5 +2187,101 @@ template <> void Vec<::concrete_optimizer::dag::InstructionKeys>::truncate(::std::size_t len) { return cxxbridge1$rust_vec$concrete_optimizer$dag$InstructionKeys$truncate(this, len); } +template <> +Vec<::concrete_optimizer::restriction::LweSecretKeyInfo>::Vec() noexcept { + cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$new(this); +} +template <> +void Vec<::concrete_optimizer::restriction::LweSecretKeyInfo>::drop() noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$drop(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::restriction::LweSecretKeyInfo>::size() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$len(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::restriction::LweSecretKeyInfo>::capacity() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$capacity(this); +} +template <> +::concrete_optimizer::restriction::LweSecretKeyInfo const *Vec<::concrete_optimizer::restriction::LweSecretKeyInfo>::data() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$data(this); +} +template <> +void Vec<::concrete_optimizer::restriction::LweSecretKeyInfo>::reserve_total(::std::size_t new_cap) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$reserve_total(this, new_cap); +} +template <> +void Vec<::concrete_optimizer::restriction::LweSecretKeyInfo>::set_len(::std::size_t len) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$set_len(this, len); +} +template <> +void Vec<::concrete_optimizer::restriction::LweSecretKeyInfo>::truncate(::std::size_t len) { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweSecretKeyInfo$truncate(this, len); +} +template <> +Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo>::Vec() noexcept { + cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$new(this); +} +template <> +void Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo>::drop() noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$drop(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo>::size() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$len(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo>::capacity() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$capacity(this); +} +template <> +::concrete_optimizer::restriction::LweBootstrapKeyInfo const *Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo>::data() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$data(this); +} +template <> +void Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo>::reserve_total(::std::size_t new_cap) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$reserve_total(this, new_cap); +} +template <> +void Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo>::set_len(::std::size_t len) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$set_len(this, len); +} +template <> +void Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo>::truncate(::std::size_t len) { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweBootstrapKeyInfo$truncate(this, len); +} +template <> +Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo>::Vec() noexcept { + cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$new(this); +} +template <> +void Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo>::drop() noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$drop(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo>::size() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$len(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo>::capacity() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$capacity(this); +} +template <> +::concrete_optimizer::restriction::LweKeyswitchKeyInfo const *Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo>::data() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$data(this); +} +template <> +void Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo>::reserve_total(::std::size_t new_cap) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$reserve_total(this, new_cap); +} +template <> +void Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo>::set_len(::std::size_t len) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$set_len(this, len); +} +template <> +void Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo>::truncate(::std::size_t len) { + return cxxbridge1$rust_vec$concrete_optimizer$restriction$LweKeyswitchKeyInfo$truncate(this, len); +} } // namespace cxxbridge1 } // namespace rust diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 3f56c09391..336faa20ad 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -930,9 +930,6 @@ namespace concrete_optimizer { struct Weights; enum class Encoding : ::std::uint8_t; enum class MultiParamStrategy : ::std::uint8_t; - struct GlweParameterRestrictions; - struct DecompositionParameterRestrictions; - struct ParameterRestrictions; struct Options; namespace dag { struct OperatorIndex; @@ -950,6 +947,14 @@ namespace concrete_optimizer { namespace v0 { struct Solution; } + namespace restriction { + struct RangeRestriction; + struct LweSecretKeyInfo; + struct LweBootstrapKeyInfo; + struct LweKeyswitchKeyInfo; + struct KeysetInfo; + struct KeysetRestriction; + } } namespace concrete_optimizer { @@ -1121,43 +1126,22 @@ enum class MultiParamStrategy : ::std::uint8_t { }; #endif // CXXBRIDGE1_ENUM_concrete_optimizer$MultiParamStrategy -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$GlweParameterRestrictions -#define CXXBRIDGE1_STRUCT_concrete_optimizer$GlweParameterRestrictions -struct GlweParameterRestrictions final { - ::std::shared_ptr<::std::uint64_t> log2_polynomial_size_min; - ::std::shared_ptr<::std::uint64_t> log2_polynomial_size_max; - ::std::shared_ptr<::std::uint64_t> glwe_dimension_min; - ::std::shared_ptr<::std::uint64_t> glwe_dimension_max; - - using IsRelocatable = ::std::true_type; -}; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$GlweParameterRestrictions - -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$DecompositionParameterRestrictions -#define CXXBRIDGE1_STRUCT_concrete_optimizer$DecompositionParameterRestrictions -struct DecompositionParameterRestrictions final { - ::std::shared_ptr<::std::uint64_t> log2_base_min; - ::std::shared_ptr<::std::uint64_t> log2_base_max; - ::std::shared_ptr<::std::uint64_t> level_min; - ::std::shared_ptr<::std::uint64_t> level_max; - - using IsRelocatable = ::std::true_type; -}; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$DecompositionParameterRestrictions - -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$ParameterRestrictions -#define CXXBRIDGE1_STRUCT_concrete_optimizer$ParameterRestrictions -struct ParameterRestrictions final { - ::concrete_optimizer::GlweParameterRestrictions glwe_pbs; - ::concrete_optimizer::GlweParameterRestrictions free_glwe; - ::concrete_optimizer::DecompositionParameterRestrictions br_decomposition; - ::concrete_optimizer::DecompositionParameterRestrictions ks_decomposition; - ::std::shared_ptr<::std::uint64_t> free_lwe_min; - ::std::shared_ptr<::std::uint64_t> free_lwe_max; +namespace restriction { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$RangeRestriction +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$RangeRestriction +struct RangeRestriction final { + ::rust::Vec<::std::uint64_t> glwe_log_polynomial_sizes; + ::rust::Vec<::std::uint64_t> glwe_dimensions; + ::rust::Vec<::std::uint64_t> internal_lwe_dimensions; + ::rust::Vec<::std::uint64_t> pbs_level_count; + ::rust::Vec<::std::uint64_t> pbs_base_log; + ::rust::Vec<::std::uint64_t> ks_level_count; + ::rust::Vec<::std::uint64_t> ks_base_log; using IsRelocatable = ::std::true_type; }; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$ParameterRestrictions +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$RangeRestriction +} // namespace restriction #ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Options #define CXXBRIDGE1_STRUCT_concrete_optimizer$Options @@ -1172,7 +1156,8 @@ struct Options final { bool cache_on_disk; ::std::uint32_t ciphertext_modulus_log; ::std::uint32_t fft_precision; - ::concrete_optimizer::ParameterRestrictions parameter_restrictions; + ::std::shared_ptr<::concrete_optimizer::restriction::RangeRestriction> range_restriction; + ::std::shared_ptr<::concrete_optimizer::restriction::KeysetRestriction> keyset_restriction; using IsRelocatable = ::std::true_type; }; @@ -1327,6 +1312,62 @@ struct CircuitSolution final { #endif // CXXBRIDGE1_STRUCT_concrete_optimizer$dag$CircuitSolution } // namespace dag +namespace restriction { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweSecretKeyInfo +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweSecretKeyInfo +struct LweSecretKeyInfo final { + ::std::uint64_t lwe_dimension; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweSecretKeyInfo + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweBootstrapKeyInfo +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweBootstrapKeyInfo +struct LweBootstrapKeyInfo final { + ::std::uint64_t level_count; + ::std::uint64_t base_log; + ::std::uint64_t glwe_dimension; + ::std::uint64_t polynomial_size; + ::std::uint64_t input_lwe_dimension; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweBootstrapKeyInfo + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweKeyswitchKeyInfo +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweKeyswitchKeyInfo +struct LweKeyswitchKeyInfo final { + ::std::uint64_t level_count; + ::std::uint64_t base_log; + ::std::uint64_t input_lwe_dimension; + ::std::uint64_t output_lwe_dimension; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$LweKeyswitchKeyInfo + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetInfo +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetInfo +struct KeysetInfo final { + ::rust::Vec<::concrete_optimizer::restriction::LweSecretKeyInfo> lwe_secret_keys; + ::rust::Vec<::concrete_optimizer::restriction::LweBootstrapKeyInfo> lwe_bootstrap_keys; + ::rust::Vec<::concrete_optimizer::restriction::LweKeyswitchKeyInfo> lwe_keyswitch_keys; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetInfo + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction +#define CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction +struct KeysetRestriction final { + ::concrete_optimizer::restriction::KeysetInfo info; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction +} // namespace restriction + namespace v0 { ::concrete_optimizer::v0::Solution optimize_bootstrap(::std::uint64_t precision, double noise_factor, ::concrete_optimizer::Options const &options) noexcept; } // namespace v0 diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp index cde2ee8e10..673a9b79e1 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp @@ -30,35 +30,8 @@ concrete_optimizer::Options default_options() { .cache_on_disk = true, .ciphertext_modulus_log = CIPHERTEXT_MODULUS_LOG, .fft_precision = 53, - .parameter_restrictions = concrete_optimizer::ParameterRestrictions{ - .glwe_pbs = concrete_optimizer::GlweParameterRestrictions{ - .log2_polynomial_size_min = {}, - .log2_polynomial_size_max = {}, - .glwe_dimension_min = {}, - .glwe_dimension_max = {}, - }, - .free_glwe = concrete_optimizer::GlweParameterRestrictions{ - .log2_polynomial_size_min = {}, - .log2_polynomial_size_max = {}, - .glwe_dimension_min = {}, - .glwe_dimension_max = {}, - }, - - .br_decomposition = concrete_optimizer::DecompositionParameterRestrictions{ - .log2_base_min = {}, - .log2_base_max = {}, - .level_min = {}, - .level_max = {}, - }, - .ks_decomposition = concrete_optimizer::DecompositionParameterRestrictions{ - .log2_base_min = {}, - .log2_base_max = {}, - .level_min = {}, - .level_max = {}, - }, - .free_lwe_min = {}, - .free_lwe_max = {} - } + .range_restriction = {}, + .keyset_restriction = {} }; } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs index 56122f6b29..8ffe127cc3 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs @@ -1,7 +1,7 @@ use crate::computing_cost::complexity_model::ComplexityModel; use crate::config; use crate::config::GpuPbsType; -use crate::global_parameters::{ParameterDomains, Range}; +use crate::global_parameters::{Range, DEFAULT_DOMAINS}; #[derive(Clone, Copy, Debug)] pub struct NoiseBoundConfig { @@ -37,17 +37,17 @@ pub struct SearchSpace { } impl SearchSpace { - pub fn default_cpu(parameter_domains: ParameterDomains) -> Self { - let glwe_log_polynomial_sizes: Vec = parameter_domains + pub fn default_cpu() -> Self { + let glwe_log_polynomial_sizes: Vec = DEFAULT_DOMAINS .glwe_pbs_constrained_cpu .log2_polynomial_size .as_vec(); - let glwe_dimensions: Vec = parameter_domains + let glwe_dimensions: Vec = DEFAULT_DOMAINS .glwe_pbs_constrained_cpu .glwe_dimension .as_vec(); - let internal_lwe_dimensions: Vec = parameter_domains.free_glwe.glwe_dimension.as_vec(); - let levelled_only_lwe_dimensions = parameter_domains.free_lwe; + let internal_lwe_dimensions: Vec = DEFAULT_DOMAINS.free_glwe.glwe_dimension.as_vec(); + let levelled_only_lwe_dimensions = DEFAULT_DOMAINS.free_lwe; Self { glwe_log_polynomial_sizes, glwe_dimensions, @@ -56,20 +56,17 @@ impl SearchSpace { } } - pub fn default_gpu_lowlat(parameter_domains: ParameterDomains) -> Self { + pub fn default_gpu_lowlat() -> Self { // See backends/concrete_cuda/implementation/src/bootstrap_low_latency.cu - let glwe_log_polynomial_sizes: Vec = parameter_domains - .glwe_pbs_constrained_gpu - .log2_polynomial_size - .as_vec(); + let glwe_log_polynomial_sizes: Vec = (8..=14).collect(); - let glwe_dimensions: Vec = parameter_domains + let glwe_dimensions: Vec = DEFAULT_DOMAINS .glwe_pbs_constrained_gpu .glwe_dimension .as_vec(); - let internal_lwe_dimensions: Vec = parameter_domains.free_glwe.glwe_dimension.as_vec(); - let levelled_only_lwe_dimensions = parameter_domains.free_lwe; + let internal_lwe_dimensions: Vec = DEFAULT_DOMAINS.free_glwe.glwe_dimension.as_vec(); + let levelled_only_lwe_dimensions = DEFAULT_DOMAINS.free_lwe; Self { glwe_log_polynomial_sizes, glwe_dimensions, @@ -78,20 +75,17 @@ impl SearchSpace { } } - pub fn default_gpu_amortized(parameter_domains: ParameterDomains) -> Self { + pub fn default_gpu_amortized() -> Self { // See backends/concrete_cuda/implementation/src/bootstrap_amortized.cu - let glwe_log_polynomial_sizes: Vec = parameter_domains - .glwe_pbs_constrained_gpu - .log2_polynomial_size - .as_vec(); + let glwe_log_polynomial_sizes: Vec = (8..=14).collect(); - let glwe_dimensions: Vec = parameter_domains + let glwe_dimensions: Vec = DEFAULT_DOMAINS .glwe_pbs_constrained_gpu .glwe_dimension .as_vec(); - let internal_lwe_dimensions: Vec = parameter_domains.free_glwe.glwe_dimension.as_vec(); - let levelled_only_lwe_dimensions = parameter_domains.free_lwe; + let internal_lwe_dimensions: Vec = DEFAULT_DOMAINS.free_glwe.glwe_dimension.as_vec(); + let levelled_only_lwe_dimensions = DEFAULT_DOMAINS.free_lwe; Self { glwe_log_polynomial_sizes, glwe_dimensions, @@ -99,20 +93,17 @@ impl SearchSpace { levelled_only_lwe_dimensions, } } - pub fn default( - processing_unit: config::ProcessingUnit, - parameter_domains: ParameterDomains, - ) -> Self { + pub fn default(processing_unit: config::ProcessingUnit) -> Self { match processing_unit { - config::ProcessingUnit::Cpu => Self::default_cpu(parameter_domains), + config::ProcessingUnit::Cpu => Self::default_cpu(), config::ProcessingUnit::Gpu { pbs_type: GpuPbsType::Amortized, .. - } => Self::default_gpu_amortized(parameter_domains), + } => Self::default_gpu_amortized(), config::ProcessingUnit::Gpu { pbs_type: GpuPbsType::Lowlat, .. - } => Self::default_gpu_lowlat(parameter_domains), + } => Self::default_gpu_lowlat(), } } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index 024e12e949..e0b66a59e7 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -13,3 +13,5 @@ pub(crate) mod variance_constraint; mod noise_expression; mod symbolic; + +pub use partitions::PartitionIndex; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs index 413e2a8447..352f97d047 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs @@ -32,6 +32,9 @@ use super::symbolic::{bootstrap, fast_keyswitch, keyswitch}; const DEBUG: bool = false; +mod restriction; +pub use restriction::*; + #[derive(Debug, Clone)] pub struct MicroParameters { pub pbs: Vec>, @@ -79,6 +82,8 @@ type FksSrc = PartitionIndex; #[inline(never)] fn optimize_1_ks( + search_space_restriction: &impl SearchSpaceRestriction, + macro_parameters: &[MacroParameters], ks_src: KsSrc, ks_dst: KsDst, ks_input_lwe_dim: u64, @@ -93,6 +98,15 @@ fn optimize_1_ks( let ks_max_cost = complexity.evaluate_ks_max_cost(cut_complexity, &operations.cost, ks_src, ks_dst); for &ks_quantity in ks_pareto { + if !search_space_restriction.is_available_micro_ks( + ks_src, + macro_parameters[ks_src.0], + ks_dst, + macro_parameters[ks_dst.0], + ks_quantity.decomp, + ) { + continue; + } // variance is decreasing, complexity is increasing let ks_cost = ks_quantity.complexity(ks_input_lwe_dim); let ks_variance = ks_quantity.noise(ks_input_lwe_dim); @@ -111,6 +125,7 @@ fn optimize_1_ks( } fn optimize_many_independant_ks( + search_space_restriction: &impl SearchSpaceRestriction, macro_parameters: &[MacroParameters], ks_src: KsSrc, ks_input_lwe_dim: u64, @@ -139,6 +154,8 @@ fn optimize_many_independant_ks( let output_dim = macro_dst.internal_dim; let ks_pareto = caches.pareto_quantities(output_dim); let ks_best = optimize_1_ks( + search_space_restriction, + macro_parameters, ks_src, ks_dst, ks_input_lwe_dim, @@ -160,6 +177,7 @@ struct Best1FksAndManyKs { #[allow(clippy::type_complexity)] fn optimize_1_fks_and_all_compatible_ks( + search_space_restriction: &impl SearchSpaceRestriction, macro_parameters: &[MacroParameters], ks_used: &[Vec], fks_src: PartitionIndex, @@ -194,6 +212,16 @@ fn optimize_1_fks_and_all_compatible_ks( let mut fks_max_cost = complexity.evaluate_fks_max_cost(cut_complexity, &operations.cost, fks_src, fks_dst); for &ks_quantity in &ks_pareto { + if !search_space_restriction.is_available_micro_fks( + fks_src, + macro_parameters[fks_src.0], + fks_dst, + macro_parameters[fks_dst.0], + ks_quantity.decomp, + ) { + // Parameters unavailable.. + continue; + } // OPT: add a pareto cache for fks let fks_quantity = if same_dim { FksComplexityNoise { @@ -250,6 +278,7 @@ fn optimize_1_fks_and_all_compatible_ks( .set_variance(fast_keyswitch_noise(fks_src, fks_dst), fks_quantity.noise); let sol = optimize_many_independant_ks( + search_space_restriction, macro_parameters, ks_src, ks_input_dim, @@ -285,6 +314,7 @@ fn optimize_1_fks_and_all_compatible_ks( } fn optimize_dst_exclusive_fks_subset_and_all_ks( + search_space_restriction: &impl SearchSpaceRestriction, macro_parameters: &[MacroParameters], fks_paretos: &[Option], ks_used: &[Vec], @@ -307,6 +337,7 @@ fn optimize_dst_exclusive_fks_subset_and_all_ks( .sample_extract_lwe_dimension(); if let Some(fks_src) = maybe_fks_pareto { let (bests, operations) = optimize_1_fks_and_all_compatible_ks( + search_space_restriction, macro_parameters, ks_used, *fks_src, @@ -324,6 +355,7 @@ fn optimize_dst_exclusive_fks_subset_and_all_ks( } else { // There is no fks to optimize let (many_ks, operations) = optimize_many_independant_ks( + search_space_restriction, macro_parameters, PartitionIndex(ks_src), ks_input_lwe_dim, @@ -344,7 +376,8 @@ fn optimize_dst_exclusive_fks_subset_and_all_ks( fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( partition: PartitionIndex, macro_parameters: &[MacroParameters], - internal_dim: u64, + macro_param_partition: MacroParameters, + search_space_restriction: &impl SearchSpaceRestriction, cmux_pareto: &[CmuxComplexityNoise], fks_paretos: &[Option], ks_used: &[Vec], @@ -368,15 +401,24 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( for &cmux_quantity in cmux_pareto { // increasing complexity, decreasing variance + if !search_space_restriction.is_available_micro_pbs( + partition, + macro_param_partition, + cmux_quantity.decomp, + ) { + // Parameters are not available + continue; + } + // Lower bounds cuts - let pbs_cost = cmux_quantity.complexity_br(internal_dim); + let pbs_cost = cmux_quantity.complexity_br(macro_param_partition.internal_dim); operations.cost.set_cost(bootstrap(partition), pbs_cost); let lower_cost = complexity.evaluate_total_cost(&operations.cost); if lower_cost > best_sol_complexity { continue; } - let pbs_variance = cmux_quantity.noise_br(internal_dim); + let pbs_variance = cmux_quantity.noise_br(macro_param_partition.internal_dim); if pbs_variance > pbs_max_feasible_variance { continue; } @@ -385,6 +427,7 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( .variance .set_variance(bootstrap_noise(partition), pbs_variance); let sol = optimize_dst_exclusive_fks_subset_and_all_ks( + search_space_restriction, macro_parameters, fks_paretos, ks_used, @@ -663,6 +706,7 @@ fn optimize_macro( ciphertext_modulus_log: u32, fft_precision: u32, search_space: &SearchSpace, + search_space_restriction: &impl SearchSpaceRestriction, partition: PartitionIndex, used_tlu_keyswitch: &[Vec], used_conversion_keyswitch: &[Vec], @@ -704,6 +748,16 @@ fn optimize_macro( }); let mut lb_message = None; for (glwe_dimension, log2_polynomial_size) in glwe_params_domain { + if !search_space_restriction.is_available_glwe( + partition, + GlweParameters { + glwe_dimension, + log2_polynomial_size, + }, + ) { + // No parameters with these macro parameters are available in the search space. + continue; + } let glwe_params = GlweParameters { log2_polynomial_size, glwe_dimension, @@ -717,6 +771,16 @@ fn optimize_macro( } for &internal_dim in &search_space.internal_lwe_dimensions { + if !search_space_restriction.is_available_macro( + partition, + MacroParameters { + glwe_params, + internal_dim, + }, + ) { + // No parameters with these macro parameters are available in the search space + continue; + } let mut operations = operations.clone(); // OPT: fast linear noise_modulus_switching let variance_modulus_switching = @@ -850,7 +914,8 @@ fn optimize_macro( let micro_opt = optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( partition, ¯os, - internal_dim, + macro_param_partition, + search_space_restriction, cmux_pareto, &fks_to_optimize, used_tlu_keyswitch, @@ -926,10 +991,6 @@ fn optimize_macro( is_lower_bound, is_feasible: Feasibility::Feasible, }; - } else { - // the macro parameters are feasible - // but the complexity is not good enough due to previous feasible solution - assert!(best_parameters.is_feasible.is_feasible()); } } } @@ -949,6 +1010,7 @@ pub fn optimize( dag: &Dag, config: Config, search_space: &SearchSpace, + search_space_restriction: &impl SearchSpaceRestriction, persistent_caches: &PersistDecompCaches, p_cut: &Option, default_partition: PartitionIndex, @@ -962,11 +1024,6 @@ pub fn optimize( ciphertext_modulus_log, }; - let dag_p_cut = p_cut.as_ref().map_or_else( - || PartitionCut::for_each_precision(dag), - std::clone::Clone::clone, - ); - let dag = analyze(dag, &noise_config, p_cut, default_partition)?; let kappa = error::sigma_scale_of_error_probability(config.maximum_acceptable_error_probability); @@ -1001,37 +1058,47 @@ pub fn optimize( let mut best_params: Option = None; for iter in 0..=10 { for partition in PartitionIndex::range(0, nb_partitions).rev() { - // reduce search space to the parameters of external partitions - let partition_search_space = if dag_p_cut.is_external_partition(&partition) { - let external_part = - &dag_p_cut.external_partitions[partition.0 - dag_p_cut.n_internal_partitions()]; - let mut reduced_search_space = search_space.clone(); - reduced_search_space.glwe_dimensions = - [external_part.macro_params.glwe_params.glwe_dimension].to_vec(); - reduced_search_space.glwe_log_polynomial_sizes = - [external_part.macro_params.glwe_params.log2_polynomial_size].to_vec(); - reduced_search_space.internal_lwe_dimensions = - [external_part.macro_params.internal_dim].to_vec(); - reduced_search_space - } else { - search_space.clone() + let new_params = match p_cut { + Some(p_cut) => { + let search_space_restriction = ( + search_space_restriction, + ExternalPartitionRestriction(p_cut.clone()), + ); + optimize_macro( + security_level, + ciphertext_modulus_log, + fft_precision, + search_space, + &search_space_restriction, + partition, + &used_tlu_keyswitch, + &used_conversion_keyswitch, + &feasible, + &complexity, + &mut caches, + ¶ms, + best_complexity, + best_p_error, + ) + } + None => optimize_macro( + security_level, + ciphertext_modulus_log, + fft_precision, + search_space, + search_space_restriction, + partition, + &used_tlu_keyswitch, + &used_conversion_keyswitch, + &feasible, + &complexity, + &mut caches, + ¶ms, + best_complexity, + best_p_error, + ), }; - let new_params = optimize_macro( - security_level, - ciphertext_modulus_log, - fft_precision, - &partition_search_space, - partition, - &used_tlu_keyswitch, - &used_conversion_keyswitch, - &feasible, - &complexity, - &mut caches, - ¶ms, - best_complexity, - best_p_error, - ); assert!( new_params.is_feasible.is_feasible() || !params.is_feasible.is_feasible(), "Cannot degrade feasibility" @@ -1095,6 +1162,9 @@ pub fn optimize( unfeasible_constraint.to_owned(), ))); } + Feasibility::Unknown => { + return Err(optimization::Err::NoParametersFound); + } _ => unreachable!(), } } @@ -1265,6 +1335,7 @@ pub fn optimize_to_circuit_solution( dag: &Dag, config: Config, search_space: &SearchSpace, + search_space_restriction: &impl SearchSpaceRestriction, persistent_caches: &PersistDecompCaches, p_cut: &Option, ) -> keys_spec::CircuitSolution { @@ -1283,6 +1354,7 @@ pub fn optimize_to_circuit_solution( dag, config, search_space, + search_space_restriction, persistent_caches, p_cut, default_partition, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/restriction.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/restriction.rs new file mode 100644 index 0000000000..8d0f8a5cec --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/restriction.rs @@ -0,0 +1,487 @@ +use crate::{ + optimization::dag::multi_parameters::{ + partition_cut::PartitionCut, partitions::PartitionIndex, + }, + parameters::{BrDecompositionParameters, GlweParameters, KsDecompositionParameters}, +}; + +use super::MacroParameters; + +/// A trait to restrict search space in the optimization algorithm. +/// +/// The trait methods are called at different level of the optimization algorithm to +/// perform cuts depending on whether the parameters are available. +pub trait SearchSpaceRestriction { + /// Return whether the glwe parameters are available for the given partition. + fn is_available_glwe(&self, partition: PartitionIndex, glwe_params: GlweParameters) -> bool; + + /// Return whether the macro parameters are available for the given partition. + fn is_available_macro( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + ) -> bool; + + /// Return whether the pbs parameters are available for the given partition. + fn is_available_micro_pbs( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + pbs_parameters: BrDecompositionParameters, + ) -> bool; + + /// Return whether the ks parameters are available for the given partitions. + fn is_available_micro_ks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool; + + /// Return whether the fks parameters are available for the given partitions. + fn is_available_micro_fks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool; +} + +// Allow references to restrictions to be used as restrictions. +impl SearchSpaceRestriction for &A { + fn is_available_glwe(&self, partition: PartitionIndex, glwe_params: GlweParameters) -> bool { + (*self).is_available_glwe(partition, glwe_params) + } + + fn is_available_macro( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + ) -> bool { + (*self).is_available_macro(partition, macro_parameters) + } + + fn is_available_micro_pbs( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + pbs_parameters: BrDecompositionParameters, + ) -> bool { + (*self).is_available_micro_pbs(partition, macro_parameters, pbs_parameters) + } + + fn is_available_micro_ks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + (*self).is_available_micro_ks( + from_partition, + from_macro, + to_partition, + to_macro, + ks_parameters, + ) + } + + fn is_available_micro_fks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + (*self).is_available_micro_fks( + from_partition, + from_macro, + to_partition, + to_macro, + ks_parameters, + ) + } +} + +// Allow tuples of restrictions to be used as restriction +macro_rules! impl_tuple { + ($($gen_ty: ident),*) => { + impl<$($gen_ty : SearchSpaceRestriction),*> SearchSpaceRestriction for ($($gen_ty),*){ + + #[allow(non_snake_case)] + fn is_available_glwe(&self, partition: PartitionIndex, glwe_params: GlweParameters) -> bool{ + let ($($gen_ty),*) = self; + $($gen_ty.is_available_glwe(partition, glwe_params))&&* + } + + #[allow(non_snake_case)] + fn is_available_macro( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + ) -> bool { + let ($($gen_ty),*) = self; + $($gen_ty.is_available_macro(partition, macro_parameters))&&* + } + + #[allow(non_snake_case)] + fn is_available_micro_pbs( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + pbs_parameters: BrDecompositionParameters, + ) -> bool { + let ($($gen_ty),*) = self; + $($gen_ty.is_available_micro_pbs(partition, macro_parameters, pbs_parameters))&&* + } + + #[allow(non_snake_case)] + fn is_available_micro_ks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + let ($($gen_ty),*) = self; + $($gen_ty.is_available_micro_ks(from_partition, from_macro, to_partition, to_macro, ks_parameters))&&* + } + + #[allow(non_snake_case)] + fn is_available_micro_fks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + let ($($gen_ty),*) = self; + $($gen_ty.is_available_micro_fks(from_partition, from_macro, to_partition, to_macro, ks_parameters))&&* + } + } + }; +} + +impl_tuple! {A, B} +impl_tuple! {A, B, C} +impl_tuple! {A, B, C, D} +impl_tuple! {A, B, C, D, E} + +/// A restriction performing no restriction at all. +pub struct NoSearchSpaceRestriction; + +impl SearchSpaceRestriction for NoSearchSpaceRestriction { + fn is_available_glwe(&self, _partition: PartitionIndex, _glwe_params: GlweParameters) -> bool { + true + } + + fn is_available_macro( + &self, + _partition: PartitionIndex, + _macro_parameters: MacroParameters, + ) -> bool { + true + } + + fn is_available_micro_pbs( + &self, + _partition: PartitionIndex, + _macro_parameters: MacroParameters, + _pbs_parameters: BrDecompositionParameters, + ) -> bool { + true + } + + fn is_available_micro_ks( + &self, + _from_partition: PartitionIndex, + _from_macro: MacroParameters, + _to_partition: PartitionIndex, + _to_macro: MacroParameters, + _ks_parameters: KsDecompositionParameters, + ) -> bool { + true + } + + fn is_available_micro_fks( + &self, + _from_partition: PartitionIndex, + _from_macro: MacroParameters, + _to_partition: PartitionIndex, + _to_macro: MacroParameters, + _ks_parameters: KsDecompositionParameters, + ) -> bool { + true + } +} + +/// An object restricting the search space based on smaller ranges. +pub struct RangeRestriction { + pub glwe_log_polynomial_sizes: Vec, + pub glwe_dimensions: Vec, + pub internal_lwe_dimensions: Vec, + pub pbs_level_count: Vec, + pub pbs_base_log: Vec, + pub ks_level_count: Vec, + pub ks_base_log: Vec, +} + +impl SearchSpaceRestriction for RangeRestriction { + fn is_available_glwe(&self, _partition: PartitionIndex, glwe_params: GlweParameters) -> bool { + (self.glwe_dimensions.is_empty() + || self.glwe_dimensions.contains(&glwe_params.glwe_dimension)) + && (self.glwe_log_polynomial_sizes.is_empty() + || self + .glwe_log_polynomial_sizes + .contains(&glwe_params.log2_polynomial_size)) + } + + fn is_available_macro( + &self, + _partition: PartitionIndex, + macro_parameters: MacroParameters, + ) -> bool { + (self.glwe_dimensions.is_empty() + || self + .glwe_dimensions + .contains(¯o_parameters.glwe_params.glwe_dimension)) + && (self.glwe_log_polynomial_sizes.is_empty() + || self + .glwe_log_polynomial_sizes + .contains(¯o_parameters.glwe_params.log2_polynomial_size)) + && (self.internal_lwe_dimensions.is_empty() + || self + .internal_lwe_dimensions + .contains(¯o_parameters.internal_dim)) + } + + fn is_available_micro_pbs( + &self, + _partition: PartitionIndex, + macro_parameters: MacroParameters, + pbs_parameters: BrDecompositionParameters, + ) -> bool { + (self.glwe_dimensions.is_empty() + || self + .glwe_dimensions + .contains(¯o_parameters.glwe_params.glwe_dimension)) + && (self.glwe_log_polynomial_sizes.is_empty() + || self + .glwe_log_polynomial_sizes + .contains(¯o_parameters.glwe_params.log2_polynomial_size)) + && (self.internal_lwe_dimensions.is_empty() + || self + .internal_lwe_dimensions + .contains(¯o_parameters.internal_dim)) + && (self.pbs_base_log.is_empty() + || self.pbs_base_log.contains(&pbs_parameters.log2_base)) + && (self.pbs_level_count.is_empty() + || self.pbs_level_count.contains(&pbs_parameters.level)) + } + + fn is_available_micro_ks( + &self, + _from_partition: PartitionIndex, + _from_macro: MacroParameters, + _to_partition: PartitionIndex, + _to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + (self.ks_base_log.is_empty() || self.ks_base_log.contains(&ks_parameters.log2_base)) + && (self.ks_level_count.is_empty() + || self.ks_level_count.contains(&ks_parameters.level)) + } + + fn is_available_micro_fks( + &self, + _from_partition: PartitionIndex, + _from_macro: MacroParameters, + _to_partition: PartitionIndex, + _to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + (self.ks_base_log.is_empty() || self.ks_base_log.contains(&ks_parameters.log2_base)) + && (self.ks_level_count.is_empty() + || self.ks_level_count.contains(&ks_parameters.level)) + } +} + +#[allow(unused)] +pub struct LweSecretKeyInfo { + lwe_dimension: u64, +} + +pub struct LweBootstrapKeyInfo { + level_count: u64, + base_log: u64, + glwe_dimension: u64, + polynomial_size: u64, + input_lwe_dimension: u64, +} + +pub struct LweKeyswitchKeyInfo { + level_count: u64, + base_log: u64, + input_lwe_dimension: u64, + output_lwe_dimension: u64, +} + +#[allow(unused)] +pub struct KeysetInfo { + lwe_secret_keys: Vec, + lwe_bootstrap_keys: Vec, + lwe_keyswitch_keys: Vec, +} + +/// An object restricting the search space based on a keyset. +pub struct KeysetRestriction { + info: KeysetInfo, +} + +impl SearchSpaceRestriction for KeysetRestriction { + fn is_available_glwe(&self, _partition: PartitionIndex, glwe_params: GlweParameters) -> bool { + return self.info.lwe_bootstrap_keys.iter().any(|k| { + k.glwe_dimension == glwe_params.glwe_dimension + && k.polynomial_size == 2_u64.pow(glwe_params.log2_polynomial_size as u32) + }); + } + + fn is_available_macro( + &self, + _partition: PartitionIndex, + macro_parameters: MacroParameters, + ) -> bool { + return self.info.lwe_bootstrap_keys.iter().any(|k| { + k.glwe_dimension == macro_parameters.glwe_params.glwe_dimension + && k.polynomial_size + == 2_u64.pow(macro_parameters.glwe_params.log2_polynomial_size as u32) + && k.input_lwe_dimension == macro_parameters.internal_dim + }); + } + + fn is_available_micro_pbs( + &self, + _partition: PartitionIndex, + macro_parameters: MacroParameters, + pbs_parameters: BrDecompositionParameters, + ) -> bool { + return self.info.lwe_bootstrap_keys.iter().any(|k| { + k.glwe_dimension == macro_parameters.glwe_params.glwe_dimension + && k.polynomial_size + == 2_u64.pow(macro_parameters.glwe_params.log2_polynomial_size as u32) + && k.input_lwe_dimension == macro_parameters.internal_dim + && k.level_count == pbs_parameters.level + && k.base_log == pbs_parameters.log2_base + }); + } + + fn is_available_micro_ks( + &self, + _from_partition: PartitionIndex, + from_macro: MacroParameters, + _to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + return self.info.lwe_keyswitch_keys.iter().any(|k| { + k.input_lwe_dimension == from_macro.glwe_params.sample_extract_lwe_dimension() + && k.output_lwe_dimension == to_macro.internal_dim + && k.level_count == ks_parameters.level + && k.base_log == ks_parameters.log2_base + }); + } + + fn is_available_micro_fks( + &self, + _from_partition: PartitionIndex, + from_macro: MacroParameters, + _to_partition: PartitionIndex, + to_macro: MacroParameters, + ks_parameters: KsDecompositionParameters, + ) -> bool { + return self.info.lwe_keyswitch_keys.iter().any(|k| { + k.input_lwe_dimension == from_macro.glwe_params.sample_extract_lwe_dimension() + && k.output_lwe_dimension == to_macro.glwe_params.sample_extract_lwe_dimension() + && k.level_count == ks_parameters.level + && k.base_log == ks_parameters.log2_base + }); + } +} + +/// An object restricting the search space for external partitions using partitioning informations. +pub struct ExternalPartitionRestriction(pub PartitionCut); + +impl SearchSpaceRestriction for ExternalPartitionRestriction { + fn is_available_glwe(&self, partition: PartitionIndex, glwe_params: GlweParameters) -> bool { + !self.0.is_external_partition(&partition) + || self.0.external_partitions[partition.0 - self.0.n_internal_partitions()] + .macro_params + .glwe_params + == glwe_params + } + + fn is_available_macro( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + ) -> bool { + !self.0.is_external_partition(&partition) + || self.0.external_partitions[partition.0 - self.0.n_internal_partitions()].macro_params + == macro_parameters + } + + fn is_available_micro_pbs( + &self, + partition: PartitionIndex, + macro_parameters: MacroParameters, + _pbs_parameters: BrDecompositionParameters, + ) -> bool { + !self.0.is_external_partition(&partition) + || self.0.external_partitions[partition.0 - self.0.n_internal_partitions()].macro_params + == macro_parameters + } + + fn is_available_micro_ks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + _ks_parameters: KsDecompositionParameters, + ) -> bool { + (!self.0.is_external_partition(&from_partition) + || self.0.external_partitions[from_partition.0 - self.0.n_internal_partitions()] + .macro_params + == from_macro) + && (!self.0.is_external_partition(&to_partition) + || self.0.external_partitions[to_partition.0 - self.0.n_internal_partitions()] + .macro_params + == to_macro) + } + + fn is_available_micro_fks( + &self, + from_partition: PartitionIndex, + from_macro: MacroParameters, + to_partition: PartitionIndex, + to_macro: MacroParameters, + _ks_parameters: KsDecompositionParameters, + ) -> bool { + (!self.0.is_external_partition(&from_partition) + || self.0.external_partitions[from_partition.0 - self.0.n_internal_partitions()] + .macro_params + == from_macro) + && (!self.0.is_external_partition(&to_partition) + || self.0.external_partitions[to_partition.0 - self.0.n_internal_partitions()] + .macro_params + == to_macro) + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index be91c3e252..4d811938b3 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -7,7 +7,6 @@ use super::*; use crate::computing_cost::cpu::CpuComplexity; use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape}; use crate::dag::unparametrized; -use crate::global_parameters::DEFAULT_DOMAINS; use crate::optimization::dag::multi_parameters::partitionning::tests::{ get_tfhers_noise_br, SHARED_CACHES, TFHERS_MACRO_PARAMS, }; @@ -38,11 +37,12 @@ fn optimize( default_partition: PartitionIndex, ) -> Option { let config = default_config(); - let search_space = SearchSpace::default_cpu(DEFAULT_DOMAINS); + let search_space = SearchSpace::default_cpu(); super::optimize( dag, config, &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, p_cut, default_partition, @@ -615,9 +615,15 @@ fn test_levelled_only() { let mut dag = unparametrized::Dag::new(); let _ = dag.add_input(22, Shape::number()); let config = default_config(); - let search_space = SearchSpace::default_cpu(DEFAULT_DOMAINS); - let sol = - super::optimize_to_circuit_solution(&dag, config, &search_space, &SHARED_CACHES, &None); + let search_space = SearchSpace::default_cpu(); + let sol = super::optimize_to_circuit_solution( + &dag, + config, + &search_space, + &NoSearchSpaceRestriction, + &SHARED_CACHES, + &None, + ); let sol_mono = solo_key::optimize::tests::optimize(&dag) .best_solution .unwrap(); @@ -646,13 +652,14 @@ fn test_big_secret_key_sharing() { key_sharing: false, ..config_sharing }; - let mut search_space = SearchSpace::default_cpu(DEFAULT_DOMAINS); + let mut search_space = SearchSpace::default_cpu(); // eprintln!("{:?}", search_space); search_space.glwe_dimensions = vec![1]; // forcing big key sharing let sol_sharing = super::optimize_to_circuit_solution( &dag, config_sharing, &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &None, ); @@ -661,6 +668,7 @@ fn test_big_secret_key_sharing() { &dag, config_no_sharing, &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &None, ); @@ -695,13 +703,14 @@ fn test_big_and_small_secret_key() { key_sharing: false, ..config_sharing }; - let mut search_space = SearchSpace::default_cpu(DEFAULT_DOMAINS); + let mut search_space = SearchSpace::default_cpu(); search_space.glwe_dimensions = vec![1]; // forcing big key sharing search_space.internal_lwe_dimensions = vec![768]; // forcing small key sharing let sol_sharing = super::optimize_to_circuit_solution( &dag, config_sharing, &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &None, ); @@ -709,6 +718,7 @@ fn test_big_and_small_secret_key() { &dag, config_no_sharing, &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &None, ); @@ -732,11 +742,12 @@ fn test_composition_2_partitions() { let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); let input2 = dag.add_dot([input1, lut3], [1, 1]); let out = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); - let search_space = SearchSpace::default_cpu(DEFAULT_DOMAINS); + let search_space = SearchSpace::default_cpu(); let normal_sol = super::optimize( &dag, default_config(), &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &None, PartitionIndex(1), @@ -748,6 +759,7 @@ fn test_composition_2_partitions() { &dag, default_config(), &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &None, PartitionIndex(1), @@ -767,11 +779,12 @@ fn test_composition_1_partition_not_composable() { let oup = dag.add_dot([lut1], [1 << 16]); let normal_config = default_config(); let composed_config = normal_config; - let search_space = SearchSpace::default_cpu(DEFAULT_DOMAINS); + let search_space = SearchSpace::default_cpu(); let normal_sol = super::optimize( &dag, normal_config, &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &None, PartitionIndex(1), @@ -781,6 +794,7 @@ fn test_composition_1_partition_not_composable() { &dag, composed_config, &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &None, PartitionIndex(1), @@ -792,7 +806,7 @@ fn test_composition_1_partition_not_composable() { #[test] fn test_maximal_multi() { let config = default_config(); - let search_space = SearchSpace::default_cpu(DEFAULT_DOMAINS); + let search_space = SearchSpace::default_cpu(); let mut dag = unparametrized::Dag::new(); let input = dag.add_input(8, Shape::number()); let lut1 = dag.add_lut(input, FunctionTable::UNKWOWN, 8u8); @@ -808,14 +822,21 @@ fn test_maximal_multi() { eprintln!("{:?}", sol.micro_params.pbs); - let sol_ref = - super::optimize_to_circuit_solution(&dag, config, &search_space, &SHARED_CACHES, &None); + let sol_ref = super::optimize_to_circuit_solution( + &dag, + config, + &search_space, + &NoSearchSpaceRestriction, + &SHARED_CACHES, + &None, + ); assert!(sol_ref.circuit_keys.secret_keys.len() == 2); let sol = super::optimize_to_circuit_solution( &dag, config, &search_space, + &NoSearchSpaceRestriction, &SHARED_CACHES, &Some(p_cut), ); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs index e2c84eb271..22ef2fb4a1 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs @@ -7,6 +7,7 @@ use crate::optimization::dag::solo_key::optimize_generic::{max_precision, Encodi use crate::optimization::decomposition::PersistDecompCaches; use crate::optimization::wop_atomic_pattern::optimize::optimize_to_circuit_solution as crt_optimize_no_dag; +use super::optimize::SearchSpaceRestriction; use super::partition_cut::PartitionCut; fn best_complexity_solution(native: CircuitSolution, crt: CircuitSolution) -> CircuitSolution { @@ -57,13 +58,23 @@ pub fn optimize( dag: &Dag, config: Config, search_space: &SearchSpace, + search_space_restriction: &impl SearchSpaceRestriction, encoding: Encoding, default_log_norm2_woppbs: f64, caches: &PersistDecompCaches, p_cut: &Option, ) -> CircuitSolution { let dag = dag.clone(); - let native = || native_optimize(&dag, config, search_space, caches, p_cut); + let native = || { + native_optimize( + &dag, + config, + search_space, + search_space_restriction, + caches, + p_cut, + ) + }; let crt = || crt_optimize(&dag, config, search_space, default_log_norm2_woppbs, caches); match encoding { Encoding::Auto => best_complexity_solution(native(), crt()), diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 2a74924758..265c45b688 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -428,7 +428,6 @@ pub(crate) mod tests { use crate::computing_cost::cpu::CpuComplexity; use crate::config; use crate::dag::operator::{FunctionTable, Precision, Shape, Weights}; - use crate::global_parameters::DEFAULT_DOMAINS; use crate::noise_estimator::p_error::repeat_p_error; use crate::optimization::config::SearchSpace; use crate::optimization::{atomic_pattern, decomposition}; @@ -480,7 +479,7 @@ pub(crate) mod tests { complexity_model: &CpuComplexity::default(), }; - let search_space = SearchSpace::default_cpu(DEFAULT_DOMAINS); + let search_space = SearchSpace::default_cpu(); super::optimize(dag, config, &search_space, &SHARED_CACHES) } @@ -512,7 +511,7 @@ pub(crate) mod tests { fn v0_parameter_ref(precision: u64, weight: u64, times: &mut Times) { let processing_unit = config::ProcessingUnit::Cpu; - let search_space = SearchSpace::default(processing_unit, DEFAULT_DOMAINS); + let search_space = SearchSpace::default(processing_unit); let sum_size = 1; @@ -610,7 +609,7 @@ pub(crate) mod tests { assert_f64_eq(square(weight) as f64, constraint.pareto_in_lut[0].lut_coeff); } - let search_space = SearchSpace::default(processing_unit, DEFAULT_DOMAINS); + let search_space = SearchSpace::default(processing_unit); let config = Config { security_level, diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 90c856c9a9..421dda91d6 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -9,6 +9,7 @@ from typing import List, Optional, Tuple, Union, get_type_hints import numpy as np +from mlir._mlir_libs._concretelang._compiler import KeysetRestriction, RangeRestriction from ..dtypes import Integer from ..representation import GraphProcessor @@ -994,6 +995,8 @@ class Configuration: dynamic_assignment_check_out_of_bounds: bool simulate_encrypt_run_decrypt: bool composable: bool + range_restriction: Optional[RangeRestriction] + keyset_restriction: Optional[KeysetRestriction] def __init__( self, @@ -1063,6 +1066,8 @@ def __init__( dynamic_indexing_check_out_of_bounds: bool = True, dynamic_assignment_check_out_of_bounds: bool = True, simulate_encrypt_run_decrypt: bool = False, + range_restriction: Optional[RangeRestriction] = None, + keyset_restriction: Optional[KeysetRestriction] = None, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1169,6 +1174,8 @@ def __init__( self.dynamic_assignment_check_out_of_bounds = dynamic_assignment_check_out_of_bounds self.simulate_encrypt_run_decrypt = simulate_encrypt_run_decrypt + self.range_restriction = range_restriction + self.keyset_restriction = keyset_restriction self._validate() @@ -1245,6 +1252,8 @@ def fork( dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP, dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP, simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP, + range_restriction: Union[Keep, Optional[RangeRestriction]] = KEEP, + keyset_restriction: Union[Keep, Optional[KeysetRestriction]] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 81430f93e6..05f747fd5f 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -180,6 +180,11 @@ def create( options.set_enable_tlu_fusing(configuration.enable_tlu_fusing) options.set_print_tlu_fusing(configuration.print_tlu_fusing) + if configuration.keyset_restriction: + options.set_keyset_restriction(configuration.keyset_restriction) + + if configuration.range_restriction: + options.set_range_restriction(configuration.range_restriction) try: if configuration.compiler_debug_mode: # pragma: no cover diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index bff25a89c8..02e4d3f85b 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -223,7 +223,7 @@ def keygen_with_initial_keys( continue key_buffer = input_idx_to_key_buffer[input_idx] - param = self.circuit.client.specs.program_info.secret_keys()[key_id] + param = self.circuit.client.specs.program_info.get_keyset_info().secret_keys()[key_id] try: initial_keys[key_id] = LweSecretKey.deserialize(key_buffer, param) except Exception as e: # pragma: no cover diff --git a/frontends/concrete-python/tests/compilation/test_restrictions.py b/frontends/concrete-python/tests/compilation/test_restrictions.py new file mode 100644 index 0000000000..de96cb85b1 --- /dev/null +++ b/frontends/concrete-python/tests/compilation/test_restrictions.py @@ -0,0 +1,98 @@ +""" +Tests of everything related to restrictions. +""" + +import numpy as np +import pytest +from mlir._mlir_libs._concretelang._compiler import KeysetRestriction, RangeRestriction + +from concrete import fhe + +# pylint: disable=missing-class-docstring, missing-function-docstring, no-self-argument, unused-variable, no-member, unused-argument, function-redefined, expression-not-assigned +# same disables for ruff: +# ruff: noqa: N805, E501, F841, ARG002, F811, B015 + + +def test_range_restriction(): + """ + Test that compiling a module works. + """ + + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 20 + + inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] + range_restriction = RangeRestriction() + internal_lwe_dimension = 999 + range_restriction.add_available_internal_lwe_dimension(internal_lwe_dimension) + glwe_log_polynomial_size = 12 + range_restriction.add_available_glwe_log_polynomial_size(glwe_log_polynomial_size) + glwe_dimension = 2 + range_restriction.add_available_glwe_dimension(glwe_dimension) + pbs_level_count = 3 + range_restriction.add_available_pbs_level_count(pbs_level_count) + pbs_base_log = 11 + range_restriction.add_available_pbs_base_log(pbs_base_log) + ks_level_count = 3 + range_restriction.add_available_ks_level_count(ks_level_count) + ks_base_log = 6 + range_restriction.add_available_ks_base_log(ks_base_log) + module = Module.compile( + {"inc": inputset}, enable_unsafe_features=True, range_restriction=range_restriction + ) + keyset_info = module.keys.specs.program_info.get_keyset_info() + assert keyset_info.bootstrap_keys()[0].polynomial_size() == 2**glwe_log_polynomial_size + assert keyset_info.bootstrap_keys()[0].input_lwe_dimension() == internal_lwe_dimension + assert keyset_info.bootstrap_keys()[0].glwe_dimension() == glwe_dimension + assert keyset_info.bootstrap_keys()[0].level() == pbs_level_count + assert keyset_info.bootstrap_keys()[0].base_log() == pbs_base_log + assert keyset_info.keyswitch_keys()[0].level() == ks_level_count + assert keyset_info.keyswitch_keys()[0].base_log() == ks_base_log + assert keyset_info.secret_keys()[0].dimension() == 2**glwe_log_polynomial_size * glwe_dimension + assert keyset_info.secret_keys()[1].dimension() == internal_lwe_dimension + + +def test_keyset_restriction(): + """ + Test that compiling a module works. + """ + + @fhe.module() + class Big: + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 200 + + big_inputset = [np.random.randint(1, 200, size=()) for _ in range(100)] + + @fhe.module() + class Small: + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 20 + + small_inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] + + big_module = Big.compile( + {"inc": big_inputset}, + enable_unsafe_features=True, + ) + big_keyset_info = big_module.keys.specs.program_info.get_keyset_info() + + small_module = Small.compile( + {"inc": small_inputset}, + enable_unsafe_features=True, + ) + small_keyset_info = small_module.keys.specs.program_info.get_keyset_info() + assert big_keyset_info != small_keyset_info + + restriction = big_keyset_info.get_restriction() + restricted_module = Small.compile( + {"inc": small_inputset}, enable_unsafe_features=True, keyset_restriction=restriction + ) + restricted_keyset_info = restricted_module.keys.specs.program_info.get_keyset_info() + assert big_keyset_info == restricted_keyset_info + assert small_keyset_info != restricted_keyset_info