From d033102a3c333ece84a91a000d6229871362b4f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Tue, 23 Apr 2024 17:19:14 +0200 Subject: [PATCH] feat(optimizer): allow circuit manipulation in optimizer dag --- .../Dialect/FHE/Analysis/ConcreteOptimizer.h | 4 +- .../include/concretelang/Support/Pipeline.h | 2 +- .../concretelang/Support/V0Parameters.h | 3 +- .../FHE/Analysis/ConcreteOptimizer.cpp | 246 ++++--- .../compiler/lib/Support/CompilerEngine.cpp | 36 +- .../compiler/lib/Support/Pipeline.cpp | 28 +- .../TFHEToConcrete/add_glwe_int.mlir | 4 +- .../TFHEToConcrete/mul_glwe_int.mlir | 3 +- .../TFHEToConcrete/sub_int_glwe.mlir | 4 +- .../FHE/Transform/boolean_transforms.mlir | 12 +- .../Dialect/TFHE/optimization.mlir | 6 +- .../src/concrete-optimizer.rs | 238 +++---- .../src/cpp/concrete-optimizer.cpp | 191 ++++-- .../src/cpp/concrete-optimizer.hpp | 41 +- .../concrete-optimizer-cpp/tests/src/main.cpp | 83 +-- .../src/dag/operator/operator.rs | 27 +- .../src/dag/rewrite/regen.rs | 23 +- .../src/dag/rewrite/round.rs | 10 +- .../src/dag/unparametrized.rs | 620 +++++++++++++----- .../dag/multi_parameters/analyze.rs | 138 ++-- .../optimization/dag/multi_parameters/mod.rs | 1 - .../dag/multi_parameters/optimize/mod.rs | 6 +- .../dag/multi_parameters/optimize/tests.rs | 64 +- .../dag/multi_parameters/optimize_generic.rs | 6 +- .../dag/multi_parameters/partition_cut.rs | 34 +- .../dag/multi_parameters/partitionning.rs | 98 ++- .../dag/multi_parameters/symbolic_variance.rs | 4 +- .../dag/multi_parameters/visualization.rs | 162 ----- .../src/optimization/dag/solo_key/analyze.rs | 237 +++---- .../src/optimization/dag/solo_key/optimize.rs | 57 +- .../dag/solo_key/optimize_generic.rs | 8 +- .../concrete-optimizer/src/utils/mod.rs | 1 + .../concrete-optimizer/src/utils/viz.rs | 174 +++++ 33 files changed, 1439 insertions(+), 1132 deletions(-) delete mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/visualization.rs create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h index f279830ab5..dd0beae315 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h @@ -17,10 +17,8 @@ namespace mlir { namespace concretelang { namespace optimizer { -using FunctionsDag = std::map>; - std::unique_ptr createDagPass(optimizer::Config config, - optimizer::FunctionsDag &dags); + concrete_optimizer::Dag &dag); } // namespace optimizer } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h index 0c787c9a01..49735900d8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h @@ -24,7 +24,7 @@ mlir::LogicalResult materializeOptimizerPartitionFrontiers( std::optional &fheContext, std::function enablePass); -llvm::Expected>> +llvm::Expected> getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, optimizer::Config config, std::function enablePass); diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h index 83d90e9aa2..24dd6153e9 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h @@ -135,7 +135,8 @@ constexpr Config DEFAULT_CONFIG = { DEFAULT_COMPOSABLE, }; -using Dag = rust::Box; +using Dag = rust::Box; +using DagBuilder = rust::Box; using DagSolution = concrete_optimizer::dag::DagSolution; using CircuitSolution = concrete_optimizer::dag::CircuitSolution; diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index d2d3a95acc..150a65359e 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -58,9 +58,11 @@ struct FunctionToDag { optimizer::Config config; llvm::DenseMap index; bool setOptimizerID; + concrete_optimizer::DagBuilder &dagBuilder; - FunctionToDag(mlir::func::FuncOp func, optimizer::Config config) - : func(func), config(config) { + FunctionToDag(mlir::func::FuncOp func, optimizer::Config config, + concrete_optimizer::DagBuilder &dagBuilder) + : func(func), config(config), dagBuilder(dagBuilder) { setOptimizerID = config.strategy == optimizer::Strategy::DAG_MULTI; } @@ -69,15 +71,12 @@ struct FunctionToDag { mlir::concretelang::log_verbose() << MSG << "\n"; \ } - outcome::checked, - ::concretelang::error::StringError> - build() { - auto dag = concrete_optimizer::dag::empty(); + void build() { // Converting arguments as Input mlir::Builder builder(func.getContext()); for (size_t i = 0; i < func.getNumArguments(); i++) { auto arg = func.getArgument(i); - auto optimizerIdx = addArg(dag, arg); + auto optimizerIdx = addArg(arg); if (optimizerIdx.has_value() && setOptimizerID) { func.setArgAttr(i, "TFHE.OId", builder.getI32IntegerAttr(optimizerIdx->index)); @@ -86,7 +85,7 @@ struct FunctionToDag { // Converting ops for (auto &bb : func.getBody().getBlocks()) { for (auto &op : bb.getOperations()) { - addOperation(dag, op); + addOperation(op); } } for (auto &bb : func.getBody().getBlocks()) { @@ -98,21 +97,20 @@ struct FunctionToDag { // Dag is empty <=> classical function without encryption DEBUG("!!! concrete-optimizer: nothing to do in " << func.getName() << "\n"); - return std::nullopt; + return; }; - DEBUG(std::string(dag->dump())); - return std::move(dag); + DEBUG(std::string(dagBuilder.dump())); } std::optional - addArg(optimizer::Dag &dag, mlir::Value &arg) { + addArg(mlir::Value &arg) { DEBUG("Arg " << arg << " " << arg.getType()); if (!fhe::utils::isEncryptedValue(arg)) { return std::nullopt; } auto precision = fhe::utils::getEintPrecision(arg); auto shape = getShape(arg); - auto opI = dag->add_input(precision, slice(shape)); + auto opI = dagBuilder.add_input(precision, slice(shape)); index[arg] = opI; return opI; } @@ -126,13 +124,13 @@ struct FunctionToDag { return false; } - void addOperation(optimizer::Dag &dag, mlir::Operation &op) { + void addOperation(mlir::Operation &op) { DEBUG("Instr " << op); auto encrypted_inputs = encryptedInputs(op); if (isReturn(op)) { for (auto op : encrypted_inputs) { - dag->tag_operator_as_output(op); + dagBuilder.tag_operator_as_output(op); } return; } @@ -148,61 +146,60 @@ struct FunctionToDag { auto precision = fhe::utils::getEintPrecision(val); concrete_optimizer::dag::OperatorIndex index; if (auto inputType = isLut(op); inputType != nullptr) { - addLut(dag, op, inputType, encrypted_inputs, precision); + addLut(op, inputType, encrypted_inputs, precision); return; } else if (isRound(op)) { - index = addRound(dag, val, encrypted_inputs, precision); + index = addRound(val, encrypted_inputs, precision); } else if (isReinterpretPrecision(op)) { - index = addReinterpretPrecision(dag, val, encrypted_inputs, precision); + index = addReinterpretPrecision(val, encrypted_inputs, precision); } else if (auto lsb = asLsb(op)) { - addLsb(dag, lsb, encrypted_inputs); + addLsb(lsb, encrypted_inputs); return; } else if (auto lsb = asLsbTensor(op)) { - addLsb(dag, lsb, encrypted_inputs); + addLsb(lsb, encrypted_inputs); return; } else if (auto dot = asDot(op)) { auto weightsOpt = dotWeights(dot); if (weightsOpt) { - index = addDot(dag, val, encrypted_inputs, weightsOpt.value()); + index = addDot(val, encrypted_inputs, weightsOpt.value()); } else { // If can't find weights return default leveled op DEBUG("Replace Dot by LevelledOp on " << op); - index = addLevelledOp(dag, op, encrypted_inputs); + index = addLevelledOp(op, encrypted_inputs); } } else if (auto dot = asDotEint(op)) { - addDotEint(dag, dot, encrypted_inputs, precision); + addDotEint(dot, encrypted_inputs, precision); // The above function call sets the OIds, can return right away return; } else if (auto mul = asMul(op)) { // special case as mul are rewritten in several optimizer nodes - addMul(dag, mul, encrypted_inputs, precision); + addMul(mul, encrypted_inputs, precision); return; } else if (auto mul = asMulTensor(op)) { // special case as mul are rewritten in several optimizer nodes - addMul(dag, mul, encrypted_inputs, precision); + addMul(mul, encrypted_inputs, precision); return; } else if (auto max = asMax(op)) { // special case as max are rewritten in several optimizer nodes - addMax(dag, max, encrypted_inputs, precision); + addMax(max, encrypted_inputs, precision); return; } else if (auto maxpool2d = asMaxpool2d(op)) { // special case as max are rewritten in several optimizer nodes - addMaxpool2d(dag, maxpool2d, encrypted_inputs, precision); + addMaxpool2d(maxpool2d, encrypted_inputs, precision); return; } else if (auto matmulEintEint = asMatmulEintEint(op)) { - addEncMatMulTensor(dag, matmulEintEint, encrypted_inputs, precision); + addEncMatMulTensor(matmulEintEint, encrypted_inputs, precision); return; } else { - index = addLevelledOp(dag, op, encrypted_inputs); + index = addLevelledOp(op, encrypted_inputs); } mlir::Builder builder(op.getContext()); if (setOptimizerID) op.setAttr("TFHE.OId", builder.getI32IntegerAttr(index.index)); } - void addLut(optimizer::Dag &dag, mlir::Operation &op, - FHE::FheIntegerInterface inputType, Inputs &encrypted_inputs, - int precision) { + void addLut(mlir::Operation &op, FHE::FheIntegerInterface inputType, + Inputs &encrypted_inputs, int precision) { auto val = op.getResult(0); assert(encrypted_inputs.size() == 1); // No need to distinguish different lut kind until we do approximate @@ -212,13 +209,13 @@ struct FunctionToDag { std::vector operatorIndexes; if (inputType.isSigned()) { // std::vector weights_vector{1}; - auto addIndex = dag->add_dot(slice(encrypted_inputs), - concrete_optimizer::weights::number(1)); + auto addIndex = dagBuilder.add_dot( + slice(encrypted_inputs), concrete_optimizer::weights::number(1)); encrypted_input = addIndex; operatorIndexes.push_back(addIndex.index); } auto lutIndex = - dag->add_lut(encrypted_input, slice(unknowFunction), precision); + dagBuilder.add_lut(encrypted_input, slice(unknowFunction), precision); operatorIndexes.push_back(lutIndex.index); mlir::Builder builder(op.getContext()); if (setOptimizerID) @@ -226,33 +223,32 @@ struct FunctionToDag { index[val] = lutIndex; } - concrete_optimizer::dag::OperatorIndex addRound(optimizer::Dag &dag, - mlir::Value &val, - Inputs &encrypted_inputs, - int rounded_precision) { + concrete_optimizer::dag::OperatorIndex + addRound(mlir::Value &val, Inputs &encrypted_inputs, int rounded_precision) { assert(encrypted_inputs.size() == 1); // No need to distinguish different lut kind until we do approximate // paradigm on outputs auto encrypted_input = encrypted_inputs[0]; - index[val] = dag->add_round_op(encrypted_input, rounded_precision); + index[val] = dagBuilder.add_round_op(encrypted_input, rounded_precision); return index[val]; } concrete_optimizer::dag::OperatorIndex - addReinterpretPrecision(optimizer::Dag &dag, mlir::Value &val, - Inputs &encrypted_inputs, int new_precision) { + addReinterpretPrecision(mlir::Value &val, Inputs &encrypted_inputs, + int new_precision) { assert(encrypted_inputs.size() == 1); auto encrypted_input = encrypted_inputs[0]; - index[val] = dag->add_unsafe_cast_op(encrypted_input, new_precision); + index[val] = dagBuilder.add_unsafe_cast_op(encrypted_input, new_precision); return index[val]; } concrete_optimizer::dag::OperatorIndex - addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, + addDot(mlir::Value &val, Inputs &encrypted_inputs, std::vector &weights_vector) { assert(encrypted_inputs.size() == 1); auto weights = concrete_optimizer::weights::vector(slice(weights_vector)); - index[val] = dag->add_dot(slice(encrypted_inputs), std::move(weights)); + index[val] = + dagBuilder.add_dot(slice(encrypted_inputs), std::move(weights)); return index[val]; } @@ -263,15 +259,15 @@ struct FunctionToDag { return loc; } - concrete_optimizer::dag::OperatorIndex - addLevelledOp(optimizer::Dag &dag, mlir::Operation &op, Inputs &inputs) { + concrete_optimizer::dag::OperatorIndex addLevelledOp(mlir::Operation &op, + Inputs &inputs) { auto val = op.getResult(0); auto out_shape = getShape(val); if (inputs.empty()) { // Trivial encrypted constants encoding // There are converted to input + levelledop auto precision = fhe::utils::getEintPrecision(val); - auto opI = dag->add_input(precision, slice(out_shape)); + auto opI = dagBuilder.add_input(precision, slice(out_shape)); inputs.push_back(opI); } // Default complexity is negligible @@ -284,8 +280,8 @@ struct FunctionToDag { double manp = sqrt(smanp_int.getValue().roundToDouble()); auto comment = std::string(op.getName().getStringRef()) + " " + loc; index[val] = - dag->add_levelled_op(slice(inputs), lwe_dim_cost_factor, fixed_cost, - manp, slice(out_shape), comment); + dagBuilder.add_levelled_op(slice(inputs), lwe_dim_cost_factor, + fixed_cost, manp, slice(out_shape), comment); return index[val]; } @@ -297,22 +293,22 @@ struct FunctionToDag { } template - void addLsb(optimizer::Dag &dag, LsbOp &lsbOp, Inputs &encrypted_inputs) { + void addLsb(LsbOp &lsbOp, Inputs &encrypted_inputs) { assert(encrypted_inputs.size() == 1); auto input = lsbOp.getInput(); auto result = lsbOp.getResult(); auto input_precision = fhe::utils::getEintPrecision(input); auto output_precision = fhe::utils::getEintPrecision(result); - auto lsb_shiffted_as_1bit_wop = - dag->add_dot(slice(encrypted_inputs), - concrete_optimizer::weights::number(1 << input_precision)); + auto lsb_shiffted_as_1bit_wop = dagBuilder.add_dot( + slice(encrypted_inputs), + concrete_optimizer::weights::number(1 << input_precision)); std::vector unknownFunction; auto overflow_bit_precision = 0; - auto lsb_as_0_bits = dag->add_unsafe_cast_op( + auto lsb_as_0_bits = dagBuilder.add_unsafe_cast_op( lsb_shiffted_as_1bit_wop, overflow_bit_precision); // id for rotation - auto lsb_result = - dag->add_lut(lsb_as_0_bits, slice(unknownFunction), output_precision); - auto lsb_result_corrected = idPlaceholder(dag, lsb_result); + auto lsb_result = dagBuilder.add_lut(lsb_as_0_bits, slice(unknownFunction), + output_precision); + auto lsb_result_corrected = idPlaceholder(lsb_result); index[result] = lsb_result_corrected; if (!setOptimizerID) { @@ -333,8 +329,7 @@ struct FunctionToDag { } template - void addMul(optimizer::Dag &dag, MulOp &mulOp, Inputs &inputs, - int precision) { + void addMul(MulOp &mulOp, Inputs &inputs, int precision) { // x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y) @@ -379,39 +374,40 @@ struct FunctionToDag { // tlu(x + y) auto addNode = - dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - addSubManp, slice(resultShape), comment); + dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + addSubManp, slice(resultShape), comment); std::optional lhsCorrectionNode; if (isSignedEint(mulOp.getType())) { // If signed mul we need to add the addition node for correction of the // signed tlu - addNode = dag->add_dot( + addNode = dagBuilder.add_dot( slice(std::vector{addNode}), concrete_optimizer::weights::vector( slice(std::vector{1}))); lhsCorrectionNode = addNode; } - auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision); + auto lhsTluNode = + dagBuilder.add_lut(addNode, slice(unknownFunction), precision); // tlu(x - y) auto subNode = - dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - addSubManp, slice(resultShape), comment); + dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + addSubManp, slice(resultShape), comment); // This is a signed tlu so we need to also add the addition for correction // signed tlu - auto rhsCorrectionNode = dag->add_dot( + auto rhsCorrectionNode = dagBuilder.add_dot( slice(std::vector{subNode}), concrete_optimizer::weights::vector( slice(std::vector{1}))); - auto rhsTluNode = - dag->add_lut(rhsCorrectionNode, slice(unknownFunction), precision); + auto rhsTluNode = dagBuilder.add_lut(rhsCorrectionNode, + slice(unknownFunction), precision); // tlu(x + y) - tlu(x - y) const std::vector subInputs = { lhsTluNode, rhsTluNode}; - auto resultNode = - dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost, - tluSubManp, slice(resultShape), comment); + auto resultNode = dagBuilder.add_levelled_op( + slice(subInputs), lweDimCostFactor, fixedCost, tluSubManp, + slice(resultShape), comment); index[result] = resultNode; mlir::Builder builder(mulOp.getContext()); @@ -431,8 +427,7 @@ struct FunctionToDag { template concrete_optimizer::dag::OperatorIndex - addTensorInnerProductEncEnc(optimizer::Dag &dag, - InnerProductOp &innerProductOp, Inputs &inputs, + addTensorInnerProductEncEnc(InnerProductOp &innerProductOp, Inputs &inputs, int precision) { mlir::Value result = innerProductOp.getResult(); const std::vector resultShape = getShape(result); @@ -557,39 +552,40 @@ struct FunctionToDag { // tlu(x + y) auto addNode = - dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - addSubManp, slice(pairMatrixShape), comment); + dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + addSubManp, slice(pairMatrixShape), comment); std::optional lhsCorrectionNode; if (isSignedEint(innerProductOp.getType())) { // If signed mul we need to add the addition node for correction of the // signed tlu - addNode = dag->add_dot( + addNode = dagBuilder.add_dot( slice(std::vector{addNode}), concrete_optimizer::weights::vector( slice(std::vector{1}))); lhsCorrectionNode = addNode; } - auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision); + auto lhsTluNode = + dagBuilder.add_lut(addNode, slice(unknownFunction), precision); // tlu(x - y) auto subNode = - dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - addSubManp, slice(pairMatrixShape), comment); + dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + addSubManp, slice(pairMatrixShape), comment); // This is a signed tlu so we need to also add the addition for correction // signed tlu - auto rhsCorrectionNode = dag->add_dot( + auto rhsCorrectionNode = dagBuilder.add_dot( slice(std::vector{subNode}), concrete_optimizer::weights::vector( slice(std::vector{1}))); - auto rhsTluNode = - dag->add_lut(rhsCorrectionNode, slice(unknownFunction), precision); + auto rhsTluNode = dagBuilder.add_lut(rhsCorrectionNode, + slice(unknownFunction), precision); // tlu(x + y) - tlu(x - y) const std::vector subInputs = { lhsTluNode, rhsTluNode}; - auto resultNode = - dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost, - tluSubManp, slice(pairMatrixShape), comment); + auto resultNode = dagBuilder.add_levelled_op( + slice(subInputs), lweDimCostFactor, fixedCost, tluSubManp, + slice(pairMatrixShape), comment); // 3. Sum(tlu(x + y) - tlu(x - y)) // Create a leveled op that simulates concatenation. It takes @@ -610,9 +606,9 @@ struct FunctionToDag { // TODO: use APIFloat.sqrt when it's available double manp = sqrt(smanp_int.getValue().roundToDouble()); - index[result] = - dag->add_levelled_op(slice(sumOperands), lwe_dim_cost_factor, - fixed_cost, manp, slice(resultShape), comment); + index[result] = dagBuilder.add_levelled_op( + slice(sumOperands), lwe_dim_cost_factor, fixed_cost, manp, + slice(resultShape), comment); // Create the TFHE.OId attributes // The first elements of the vector are nodes for the encrypted @@ -637,22 +633,19 @@ struct FunctionToDag { } concrete_optimizer::dag::OperatorIndex - addEncMatMulTensor(optimizer::Dag &dag, FHELinalg::MatMulEintEintOp &matmulOp, - Inputs &inputs, int precision) { + addEncMatMulTensor(FHELinalg::MatMulEintEintOp &matmulOp, Inputs &inputs, + int precision) { return addTensorInnerProductEncEnc( - dag, matmulOp, inputs, precision); + matmulOp, inputs, precision); } - concrete_optimizer::dag::OperatorIndex addDotEint(optimizer::Dag &dag, - FHELinalg::DotEint &dotOp, - Inputs &inputs, - int precision) { - return addTensorInnerProductEncEnc(dag, dotOp, inputs, + concrete_optimizer::dag::OperatorIndex + addDotEint(FHELinalg::DotEint &dotOp, Inputs &inputs, int precision) { + return addTensorInnerProductEncEnc(dotOp, inputs, precision); } - void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs, - int precision) { + void addMax(FHE::MaxEintOp &maxOp, Inputs &inputs, int precision) { mlir::Value result = maxOp.getResult(); const std::vector resultShape = getShape(result); @@ -683,19 +676,20 @@ struct FunctionToDag { auto comment = std::string(maxOp->getName().getStringRef()) + " " + loc; auto subNode = - dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - subManp, slice(resultShape), comment); + dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + subManp, slice(resultShape), comment); const double tluNodeManp = 1; const std::vector unknownFunction; - auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision); + auto tluNode = + dagBuilder.add_lut(subNode, slice(unknownFunction), precision); const double addManp = sqrt(tluNodeManp + ySmanp.roundToDouble()); const std::vector addInputs = { tluNode, inputs[1]}; - auto resultNode = - dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost, - addManp, slice(resultShape), comment); + auto resultNode = dagBuilder.add_levelled_op( + slice(addInputs), lweDimCostFactor, fixedCost, addManp, + slice(resultShape), comment); index[result] = resultNode; // Set attribute on the MLIR node @@ -707,8 +701,8 @@ struct FunctionToDag { maxOp->setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes)); } - void addMaxpool2d(optimizer::Dag &dag, FHELinalg::Maxpool2dOp &maxpool2dOp, - Inputs &inputs, int precision) { + void addMaxpool2d(FHELinalg::Maxpool2dOp &maxpool2dOp, Inputs &inputs, + int precision) { mlir::Value result = maxpool2dOp.getResult(); const std::vector resultShape = getShape(result); @@ -743,19 +737,20 @@ struct FunctionToDag { std::string(maxpool2dOp->getName().getStringRef()) + " " + loc; auto subNode = - dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - subManp, slice(fakeShape), comment); + dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + subManp, slice(fakeShape), comment); const std::vector unknownFunction; - auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision); + auto tluNode = + dagBuilder.add_lut(subNode, slice(unknownFunction), precision); const double addManp = sqrt(inputSmanp.roundToDouble() + 1); const std::vector addInputs = { tluNode, inputs[0]}; - auto resultNode = - dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost, - addManp, slice(resultShape), comment); + auto resultNode = dagBuilder.add_levelled_op( + slice(addInputs), lweDimCostFactor, fixedCost, addManp, + slice(resultShape), comment); index[result] = resultNode; // Set attribute on the MLIR node mlir::Builder builder(maxpool2dOp.getContext()); @@ -773,10 +768,10 @@ struct FunctionToDag { } concrete_optimizer::dag::OperatorIndex - idPlaceholder(optimizer::Dag &dag, - concrete_optimizer::dag::OperatorIndex input) { + idPlaceholder(concrete_optimizer::dag::OperatorIndex input) { std::vector inputs = {input}; - return dag->add_dot(slice(inputs), concrete_optimizer::weights::number(1)); + return dagBuilder.add_dot(slice(inputs), + concrete_optimizer::weights::number(1)); } Inputs encryptedInputs(mlir::Operation &op) { @@ -953,24 +948,19 @@ struct FunctionToDag { struct DagPass : ConcreteOptimizerBase { optimizer::Config config; - optimizer::FunctionsDag &dags; + concrete_optimizer::Dag &dag; void runOnOperation() override { mlir::func::FuncOp func = getOperation(); auto name = std::string(func.getName()); DEBUG("ConcreteOptimizer Dag: " << name); - auto dag = FunctionToDag(func, config).build(); - if (dag) { - dags.insert( - optimizer::FunctionsDag::value_type(name, std::move(dag.value()))); - } else { - this->signalPassFailure(); - } + auto builder = dag.builder(name); + FunctionToDag(func, config, *builder).build(); } DagPass() = delete; - DagPass(optimizer::Config config, optimizer::FunctionsDag &dags) - : config(config), dags(dags) {} + DagPass(optimizer::Config config, concrete_optimizer::Dag &dag) + : config(config), dag(dag) {} }; // Create an instance of the ConcreteOptimizerPass pass. @@ -979,8 +969,8 @@ struct DagPass : ConcreteOptimizerBase { // remark containing the squared Minimal Arithmetic Noise Padding of // the equivalent dot operation. std::unique_ptr createDagPass(optimizer::Config config, - optimizer::FunctionsDag &dags) { - return std::make_unique(config, dags); + concrete_optimizer::Dag &dag) { + return std::make_unique(config, dag); } } // namespace optimizer diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index b5a258ee40..af29041ff1 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -161,6 +162,7 @@ llvm::Expected> CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) { mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); mlir::ModuleOp module = res.mlirModuleRef->get(); + auto config = this->compilerOptions.optimizerConfig; // If the values has been overwritten returns if (this->overrideMaxEintPrecision.has_value() && this->overrideMaxMANP.has_value()) { @@ -168,47 +170,23 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) { this->overrideMaxMANP.value(), this->overrideMaxEintPrecision.value()}; return optimizer::Description{constraint, std::nullopt}; } - auto config = this->compilerOptions.optimizerConfig; - auto descriptions = mlir::concretelang::pipeline::getFHEContextFromFHE( + auto description = mlir::concretelang::pipeline::getFHEContextFromFHE( mlirContext, module, config, enablePass); - if (auto err = descriptions.takeError()) { + if (auto err = description.takeError()) { return std::move(err); } - if (descriptions->empty()) { // The pass has not been run + if (!description->has_value()) { // The pass has not been run return std::nullopt; } - if (descriptions->size() > 1 && + if (description->value().dag.value()->get_circuit_count() > 1 && config.strategy != mlir::concretelang::optimizer::V0) { // Multi circuits without V0 return StreamStringError( "Multi-circuits is only supported for V0 optimization."); } - if (descriptions->size() > 1) { - auto iter = descriptions->begin(); - auto desc = std::move(iter->second); - if (!desc.has_value()) { - return StreamStringError("Expected description."); - } - if (!desc.value().dag.has_value()) { - return StreamStringError("Expected dag in description."); - } - iter++; - while (iter != descriptions->end()) { - if (!iter->second.has_value()) { - return StreamStringError("Expected description."); - } - if (!iter->second.value().dag.has_value()) { - return StreamStringError("Expected dag in description."); - } - desc->dag.value()->concat(*iter->second.value().dag.value()); - iter++; - } - return std::move(desc); - } - return std::move(descriptions->begin()->second); + return description; } -/// set the fheContext field if the v0Constraint can be computed /// set the fheContext field if the v0Constraint can be computed llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { if (compilerOptions.v0Parameter.has_value()) { diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 210d2ea03d..8dbfd97966 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -5,6 +5,7 @@ #include "llvm/Support/TargetSelect.h" +#include "concrete-optimizer.hpp" #include "concretelang/Support/CompilationFeedback.h" #include "concretelang/Support/V0Parameters.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" @@ -13,6 +14,7 @@ #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Transforms/Passes.h" #include "llvm/Support/Error.h" +#include #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" @@ -89,13 +91,12 @@ addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr pass, } } -llvm::Expected>> +llvm::Expected> getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, optimizer::Config config, std::function enablePass) { std::optional oMax2norm; std::optional oMaxWidth; - optimizer::FunctionsDag dags; mlir::PassManager pm(&context); @@ -127,26 +128,19 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, {/*.norm2 = */ ceilLog2(oMax2norm.value()), /*.p = */ oMaxWidth.value()}); } - addPotentiallyNestedPass(pm, optimizer::createDagPass(config, dags), + auto dag = concrete_optimizer::dag::empty(); + addPotentiallyNestedPass(pm, optimizer::createDagPass(config, *dag), enablePass); if (pm.run(module.getOperation()).failed()) { return StreamStringError() << "Failed to create concrete-optimizer dag\n"; } - std::map> descriptions; - for (auto &entry_dag : dags) { - if (!constraint) { - descriptions.insert( - decltype(descriptions)::value_type(entry_dag.first, std::nullopt)); - continue; - } - optimizer::Description description = {*constraint, - std::move(entry_dag.second)}; - std::optional opt_description{ - std::move(description)}; - descriptions.insert(decltype(descriptions)::value_type( - entry_dag.first, std::move(opt_description))); + std::optional description; + if (!constraint) { + description = std::nullopt; + } else { + description = {*constraint, std::move(dag)}; } - return std::move(descriptions); + return std::move(description); } mlir::LogicalResult materializeOptimizerPartitionFrontiers( diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir index da24a12983..cf9c76aaf2 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir @@ -1,4 +1,5 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s + //CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 @@ -11,6 +12,7 @@ func.func @add_glwe_const_int(%arg0: !TFHE.glwe>) -> !TFHE.glwe> } +// ----- //CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> { //CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir index ac9e5607ea..bc801dc995 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 @@ -11,6 +11,7 @@ func.func @mul_glwe_const_int(%arg0: !TFHE.glwe>) -> !TFHE.glwe> } +// ----- //CHECK: func.func @mul_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> { //CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir index 64ae8cef43..a4499768e1 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 @@ -12,6 +12,8 @@ func.func @sub_const_int_glwe(%arg0: !TFHE.glwe>) -> !TFHE.glwe> } +// ----- + //CHECK: func.func @sub_int_glwe(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> { //CHECK: %[[V0:.*]] = "Concrete.negate_lwe_tensor"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64> //CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[V0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir index a47a7bee55..af5426ca09 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes fhe-boolean-transform --action=dump-fhe %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes fhe-boolean-transform --action=dump-fhe --split-input-file %s 2>&1| FileCheck %s // CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool { @@ -15,6 +15,8 @@ func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) return %1: !FHE.ebool } +// ----- + // CHECK-LABEL: func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { // CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 0, 0, 1]> : tensor<4xi64> @@ -31,6 +33,8 @@ func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { return %1: !FHE.ebool } +// ----- + // CHECK-LABEL: func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { // CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 1]> : tensor<4xi64> @@ -47,6 +51,8 @@ func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { return %1: !FHE.ebool } +// ----- + // CHECK-LABEL: func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { // CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[1, 1, 1, 0]> : tensor<4xi64> @@ -63,6 +69,8 @@ func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { return %1: !FHE.ebool } +// ----- + // CHECK-LABEL: func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { // CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 0]> : tensor<4xi64> @@ -79,6 +87,8 @@ func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { return %1: !FHE.ebool } +// ----- + // CHECK-LABEL: func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool { // CHECK-NEXT: %[[TT1:.*]] = arith.constant dense<[0, 0, 1, 0]> : tensor<4xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir index ef4edbfbb3..2f910f958d 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe --skip-program-info %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe --split-input-file --skip-program-info %s 2>&1| FileCheck %s // CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe>, %arg1: i64) -> !TFHE.glwe> @@ -10,6 +10,8 @@ func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe>, %arg1: return %1: !TFHE.glwe> } +// ----- + // CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !TFHE.glwe>) -> !TFHE.glwe> func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !TFHE.glwe>) -> !TFHE.glwe> { // CHECK-NEXT: %[[V1:.*]] = "TFHE.zero"() : () -> !TFHE.glwe> @@ -20,6 +22,8 @@ func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !TFHE.glwe>) -> !T return %2: !TFHE.glwe> } +// ----- + // CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !TFHE.glwe>) -> !TFHE.glwe> func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !TFHE.glwe>) -> !TFHE.glwe> { // CHECK-NEXT: %[[V1:.*]] = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe>) -> !TFHE.glwe> diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index 1e61cc33e0..0193bb302a 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -15,6 +15,7 @@ use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{ use concrete_optimizer::optimization::decomposition; use concrete_optimizer::parameters::{BrDecompositionParameters, KsDecompositionParameters}; use concrete_optimizer::utils::cache::persistent::default_cache_dir; +use concrete_optimizer::utils::viz::Viz; fn no_solution() -> ffi::Solution { ffi::Solution { @@ -221,7 +222,7 @@ impl From for ffi::DagSolution { } } -fn convert_to_circuit_solution(sol: &ffi::DagSolution, dag: &OperationDag) -> ffi::CircuitSolution { +fn convert_to_circuit_solution(sol: &ffi::DagSolution, dag: &Dag) -> ffi::CircuitSolution { let big_key = ffi::SecretLweKey { identifier: 0, polynomial_size: sol.glwe_polynomial_size, @@ -473,84 +474,19 @@ fn NO_KEY_ID() -> u64 { keys_spec::NO_KEY_ID } -pub struct OperationDag(unparametrized::OperationDag); +pub struct Dag(unparametrized::Dag); -fn empty() -> Box { - Box::new(OperationDag(unparametrized::OperationDag::new())) +fn empty() -> Box { + Box::new(Dag(unparametrized::Dag::new())) } -impl OperationDag { - fn add_input(&mut self, out_precision: Precision, out_shape: &[u64]) -> ffi::OperatorIndex { - let out_shape = Shape { - dimensions_size: out_shape.to_owned(), - }; - - self.0.add_input(out_precision, out_shape).into() - } - - fn add_lut( - &mut self, - input: ffi::OperatorIndex, - table: &[u64], - out_precision: Precision, - ) -> ffi::OperatorIndex { - let table = FunctionTable { - values: table.to_owned(), - }; - - self.0.add_lut(input.into(), table, out_precision).into() - } - - #[allow(clippy::boxed_local)] - fn add_dot( - &mut self, - inputs: &[ffi::OperatorIndex], - weights: Box, - ) -> ffi::OperatorIndex { - let inputs: Vec = inputs.iter().copied().map(Into::into).collect(); - - self.0.add_dot(inputs, weights.0).into() - } - - fn add_levelled_op( - &mut self, - inputs: &[ffi::OperatorIndex], - lwe_dim_cost_factor: f64, - fixed_cost: f64, - manp: f64, - out_shape: &[u64], - comment: &str, - ) -> ffi::OperatorIndex { - let inputs: Vec = inputs.iter().copied().map(Into::into).collect(); - - let out_shape = Shape { - dimensions_size: out_shape.to_owned(), - }; - - let complexity = LevelledComplexity { - lwe_dim_cost_factor, - fixed_cost, - }; - - self.0 - .add_levelled_op(inputs, complexity, manp, out_shape, comment) - .into() +impl Dag { + fn builder(&mut self, circuit: String) -> Box> { + Box::new(DagBuilder(self.0.builder(circuit))) } - fn add_round_op( - &mut self, - input: ffi::OperatorIndex, - rounded_precision: Precision, - ) -> ffi::OperatorIndex { - self.0.add_round_op(input.into(), rounded_precision).into() - } - - fn add_unsafe_cast_op( - &mut self, - input: ffi::OperatorIndex, - new_precision: Precision, - ) -> ffi::OperatorIndex { - self.0.add_unsafe_cast(input.into(), new_precision).into() + fn dump(&self) -> String { + self.0.viz_string() } fn optimize(&self, options: ffi::Options) -> ffi::DagSolution { @@ -595,16 +531,8 @@ impl OperationDag { } } - fn dump(&self) -> String { - self.0.dump() - } - - fn concat(&mut self, other: &Self) { - self.0.concat(&other.0); - } - - fn tag_operator_as_output(&mut self, op: ffi::OperatorIndex) { - self.0.tag_operator_as_output(op.into()); + fn get_circuit_count(&self) -> usize { + self.0.get_circuit_count() } fn optimize_multi(&self, options: ffi::Options) -> ffi::CircuitSolution { @@ -642,6 +570,91 @@ impl OperationDag { } } +pub struct DagBuilder<'dag>(unparametrized::DagBuilder<'dag>); + +impl<'dag> DagBuilder<'dag> { + fn add_input(&mut self, out_precision: Precision, out_shape: &[u64]) -> ffi::OperatorIndex { + let out_shape = Shape { + dimensions_size: out_shape.to_owned(), + }; + + self.0.add_input(out_precision, out_shape).into() + } + + fn add_lut( + &mut self, + input: ffi::OperatorIndex, + table: &[u64], + out_precision: Precision, + ) -> ffi::OperatorIndex { + let table = FunctionTable { + values: table.to_owned(), + }; + + self.0.add_lut(input.into(), table, out_precision).into() + } + + #[allow(clippy::boxed_local)] + fn add_dot( + &mut self, + inputs: &[ffi::OperatorIndex], + weights: Box, + ) -> ffi::OperatorIndex { + let inputs: Vec = inputs.iter().copied().map(Into::into).collect(); + + self.0.add_dot(inputs, weights.0).into() + } + + fn add_levelled_op( + &mut self, + inputs: &[ffi::OperatorIndex], + lwe_dim_cost_factor: f64, + fixed_cost: f64, + manp: f64, + out_shape: &[u64], + comment: &str, + ) -> ffi::OperatorIndex { + let inputs: Vec = inputs.iter().copied().map(Into::into).collect(); + + let out_shape = Shape { + dimensions_size: out_shape.to_owned(), + }; + + let complexity = LevelledComplexity { + lwe_dim_cost_factor, + fixed_cost, + }; + + self.0 + .add_levelled_op(inputs, complexity, manp, out_shape, comment) + .into() + } + + fn add_round_op( + &mut self, + input: ffi::OperatorIndex, + rounded_precision: Precision, + ) -> ffi::OperatorIndex { + self.0.add_round_op(input.into(), rounded_precision).into() + } + + fn add_unsafe_cast_op( + &mut self, + input: ffi::OperatorIndex, + new_precision: Precision, + ) -> ffi::OperatorIndex { + self.0.add_unsafe_cast(input.into(), new_precision).into() + } + + fn tag_operator_as_output(&mut self, op: ffi::OperatorIndex) { + self.0.tag_operator_as_output(op.into()); + } + + fn dump(&self) -> String { + format!("{}", self.0.get_circuit()) + } +} + pub struct Weights(operator::Weights); fn vector(weights: &[i64]) -> Box { @@ -654,14 +667,14 @@ fn number(weight: i64) -> Box { impl From for ffi::OperatorIndex { fn from(oi: OperatorIndex) -> Self { - Self { index: oi.i } + Self { index: oi.0 } } } #[allow(clippy::from_over_into)] impl Into for ffi::OperatorIndex { fn into(self) -> OperatorIndex { - OperatorIndex { i: self.index } + OperatorIndex(self.index) } } @@ -677,7 +690,7 @@ impl Into for ffi::Encoding { } } -#[allow(unused_must_use)] +#[allow(unused_must_use, clippy::needless_lifetimes)] #[cxx::bridge] mod ffi { #[namespace = "concrete_optimizer"] @@ -690,37 +703,42 @@ mod ffi { fn convert_to_dag_solution(solution: &Solution) -> DagSolution; #[namespace = "concrete_optimizer::utils"] - fn convert_to_circuit_solution( - solution: &DagSolution, - dag: &OperationDag, - ) -> CircuitSolution; + fn convert_to_circuit_solution(solution: &DagSolution, dag: &Dag) -> CircuitSolution; - type OperationDag; + type Dag; + + type DagBuilder<'dag>; #[namespace = "concrete_optimizer::dag"] - fn empty() -> Box; + fn empty() -> Box; + + unsafe fn builder(self: &mut Dag, circuit: String) -> Box>; + + fn dump(self: &Dag) -> String; - fn add_input( - self: &mut OperationDag, + fn dump(self: &DagBuilder) -> String; + + unsafe fn add_input( + self: &mut DagBuilder<'_>, out_precision: u8, out_shape: &[u64], ) -> OperatorIndex; - fn add_lut( - self: &mut OperationDag, + unsafe fn add_lut( + self: &mut DagBuilder<'_>, input: OperatorIndex, table: &[u64], out_precision: u8, ) -> OperatorIndex; - fn add_dot( - self: &mut OperationDag, + unsafe fn add_dot( + self: &mut DagBuilder<'_>, inputs: &[OperatorIndex], weights: Box, ) -> OperatorIndex; - fn add_levelled_op( - self: &mut OperationDag, + unsafe fn add_levelled_op( + self: &mut DagBuilder<'_>, inputs: &[OperatorIndex], lwe_dim_cost_factor: f64, fixed_cost: f64, @@ -729,23 +747,21 @@ mod ffi { comment: &str, ) -> OperatorIndex; - fn add_round_op( - self: &mut OperationDag, + unsafe fn add_round_op( + self: &mut DagBuilder<'_>, input: OperatorIndex, rounded_precision: u8, ) -> OperatorIndex; - fn add_unsafe_cast_op( - self: &mut OperationDag, + unsafe fn add_unsafe_cast_op( + self: &mut DagBuilder<'_>, input: OperatorIndex, rounded_precision: u8, ) -> OperatorIndex; - fn optimize(self: &OperationDag, options: Options) -> DagSolution; - - fn dump(self: &OperationDag) -> String; + unsafe fn tag_operator_as_output(self: &mut DagBuilder<'_>, op: OperatorIndex); - fn concat(self: &mut OperationDag, other: &OperationDag); + fn optimize(self: &Dag, options: Options) -> DagSolution; #[namespace = "concrete_optimizer::dag"] fn dump(self: &CircuitSolution) -> String; @@ -761,9 +777,9 @@ mod ffi { #[namespace = "concrete_optimizer::weights"] fn number(weight: i64) -> Box; - fn tag_operator_as_output(self: &mut OperationDag, op: OperatorIndex); + fn get_circuit_count(self: &Dag) -> usize; - fn optimize_multi(self: &OperationDag, options: Options) -> CircuitSolution; + fn optimize_multi(self: &Dag, options: Options) -> CircuitSolution; fn NO_KEY_ID() -> u64; } diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 77a08cc307..f3d55819d6 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -941,7 +941,8 @@ union MaybeUninit { struct PrivateFunctionalPackingBoostrapKey; struct CircuitKeys; namespace concrete_optimizer { - struct OperationDag; + struct Dag; + struct DagBuilder; struct Weights; enum class Encoding : ::std::uint8_t; enum class MultiParamStrategy : ::std::uint8_t; @@ -965,21 +966,37 @@ namespace concrete_optimizer { } namespace concrete_optimizer { -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag -#define CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag -struct OperationDag final : public ::rust::Opaque { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Dag +#define CXXBRIDGE1_STRUCT_concrete_optimizer$Dag +struct Dag final : public ::rust::Opaque { + ::rust::Box<::concrete_optimizer::DagBuilder> builder(::rust::String circuit) noexcept; + ::rust::String dump() const noexcept; + ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; + ::std::size_t get_circuit_count() const noexcept; + ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; + ~Dag() = delete; + +private: + friend ::rust::layout; + struct layout { + static ::std::size_t size() noexcept; + static ::std::size_t align() noexcept; + }; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Dag + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$DagBuilder +#define CXXBRIDGE1_STRUCT_concrete_optimizer$DagBuilder +struct DagBuilder final : public ::rust::Opaque { + ::rust::String dump() const noexcept; ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape) noexcept; ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept; ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; ::concrete_optimizer::dag::OperatorIndex add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; - ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; - ::rust::String dump() const noexcept; - void concat(::concrete_optimizer::OperationDag const &other) noexcept; void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept; - ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; - ~OperationDag() = delete; + ~DagBuilder() = delete; private: friend ::rust::layout; @@ -988,7 +1005,7 @@ struct OperationDag final : public ::rust::Opaque { static ::std::size_t align() noexcept; }; }; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$DagBuilder #ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Weights #define CXXBRIDGE1_STRUCT_concrete_optimizer$Weights @@ -1259,39 +1276,45 @@ namespace utils { extern "C" { void concrete_optimizer$utils$cxxbridge1$convert_to_dag_solution(::concrete_optimizer::v0::Solution const &solution, ::concrete_optimizer::dag::DagSolution *return$) noexcept; -void concrete_optimizer$utils$cxxbridge1$convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept; +void concrete_optimizer$utils$cxxbridge1$convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::Dag const &dag, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept; } // extern "C" } // namespace utils extern "C" { -::std::size_t concrete_optimizer$cxxbridge1$OperationDag$operator$sizeof() noexcept; -::std::size_t concrete_optimizer$cxxbridge1$OperationDag$operator$alignof() noexcept; +::std::size_t concrete_optimizer$cxxbridge1$Dag$operator$sizeof() noexcept; +::std::size_t concrete_optimizer$cxxbridge1$Dag$operator$alignof() noexcept; +::std::size_t concrete_optimizer$cxxbridge1$DagBuilder$operator$sizeof() noexcept; +::std::size_t concrete_optimizer$cxxbridge1$DagBuilder$operator$alignof() noexcept; } // extern "C" namespace dag { extern "C" { -::concrete_optimizer::OperationDag *concrete_optimizer$dag$cxxbridge1$empty() noexcept; +::concrete_optimizer::Dag *concrete_optimizer$dag$cxxbridge1$empty() noexcept; } // extern "C" } // namespace dag extern "C" { -::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_input(::concrete_optimizer::OperationDag &self, ::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape) noexcept; +::concrete_optimizer::DagBuilder *concrete_optimizer$cxxbridge1$Dag$builder(::concrete_optimizer::Dag &self, ::rust::String *circuit) noexcept; + +void concrete_optimizer$cxxbridge1$Dag$dump(::concrete_optimizer::Dag const &self, ::rust::String *return$) noexcept; -::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_lut(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision) noexcept; +void concrete_optimizer$cxxbridge1$DagBuilder$dump(::concrete_optimizer::DagBuilder const &self, ::rust::String *return$) noexcept; -::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_dot(::concrete_optimizer::OperationDag &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::concrete_optimizer::Weights *weights) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_input(::concrete_optimizer::DagBuilder &self, ::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape) noexcept; -::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_levelled_op(::concrete_optimizer::OperationDag &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_lut(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision) noexcept; -::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_round_op(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_dot(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::concrete_optimizer::Weights *weights) noexcept; -::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_unsafe_cast_op(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_levelled_op(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; -void concrete_optimizer$cxxbridge1$OperationDag$optimize(::concrete_optimizer::OperationDag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::DagSolution *return$) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_round_op(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; -void concrete_optimizer$cxxbridge1$OperationDag$dump(::concrete_optimizer::OperationDag const &self, ::rust::String *return$) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_unsafe_cast_op(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; -void concrete_optimizer$cxxbridge1$OperationDag$concat(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::OperationDag const &other) noexcept; +void concrete_optimizer$cxxbridge1$DagBuilder$tag_operator_as_output(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex op) noexcept; + +void concrete_optimizer$cxxbridge1$Dag$optimize(::concrete_optimizer::Dag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::DagSolution *return$) noexcept; } // extern "C" namespace dag { @@ -1316,9 +1339,9 @@ ::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$number(::st } // namespace weights extern "C" { -void concrete_optimizer$cxxbridge1$OperationDag$tag_operator_as_output(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex op) noexcept; +::std::size_t concrete_optimizer$cxxbridge1$Dag$get_circuit_count(::concrete_optimizer::Dag const &self) noexcept; -void concrete_optimizer$cxxbridge1$OperationDag$optimize_multi(::concrete_optimizer::OperationDag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept; +void concrete_optimizer$cxxbridge1$Dag$optimize_multi(::concrete_optimizer::Dag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept; ::std::uint64_t concrete_optimizer$cxxbridge1$NO_KEY_ID() noexcept; } // extern "C" @@ -1336,65 +1359,83 @@ ::concrete_optimizer::dag::DagSolution convert_to_dag_solution(::concrete_optimi return ::std::move(return$.value); } -::concrete_optimizer::dag::CircuitSolution convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag) noexcept { +::concrete_optimizer::dag::CircuitSolution convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::Dag const &dag) noexcept { ::rust::MaybeUninit<::concrete_optimizer::dag::CircuitSolution> return$; concrete_optimizer$utils$cxxbridge1$convert_to_circuit_solution(solution, dag, &return$.value); return ::std::move(return$.value); } } // namespace utils -::std::size_t OperationDag::layout::size() noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$operator$sizeof(); +::std::size_t Dag::layout::size() noexcept { + return concrete_optimizer$cxxbridge1$Dag$operator$sizeof(); +} + +::std::size_t Dag::layout::align() noexcept { + return concrete_optimizer$cxxbridge1$Dag$operator$alignof(); +} + +::std::size_t DagBuilder::layout::size() noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$operator$sizeof(); } -::std::size_t OperationDag::layout::align() noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$operator$alignof(); +::std::size_t DagBuilder::layout::align() noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$operator$alignof(); } namespace dag { -::rust::Box<::concrete_optimizer::OperationDag> empty() noexcept { - return ::rust::Box<::concrete_optimizer::OperationDag>::from_raw(concrete_optimizer$dag$cxxbridge1$empty()); +::rust::Box<::concrete_optimizer::Dag> empty() noexcept { + return ::rust::Box<::concrete_optimizer::Dag>::from_raw(concrete_optimizer$dag$cxxbridge1$empty()); } } // namespace dag -::concrete_optimizer::dag::OperatorIndex OperationDag::add_input(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape) noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$add_input(*this, out_precision, out_shape); +::rust::Box<::concrete_optimizer::DagBuilder> Dag::builder(::rust::String circuit) noexcept { + return ::rust::Box<::concrete_optimizer::DagBuilder>::from_raw(concrete_optimizer$cxxbridge1$Dag$builder(*this, &circuit)); } -::concrete_optimizer::dag::OperatorIndex OperationDag::add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision) noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$add_lut(*this, input, table, out_precision); +::rust::String Dag::dump() const noexcept { + ::rust::MaybeUninit<::rust::String> return$; + concrete_optimizer$cxxbridge1$Dag$dump(*this, &return$.value); + return ::std::move(return$.value); } -::concrete_optimizer::dag::OperatorIndex OperationDag::add_dot(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$add_dot(*this, inputs, weights.into_raw()); +::rust::String DagBuilder::dump() const noexcept { + ::rust::MaybeUninit<::rust::String> return$; + concrete_optimizer$cxxbridge1$DagBuilder$dump(*this, &return$.value); + return ::std::move(return$.value); } -::concrete_optimizer::dag::OperatorIndex OperationDag::add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$add_levelled_op(*this, inputs, lwe_dim_cost_factor, fixed_cost, manp, out_shape, comment); +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_input(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_input(*this, out_precision, out_shape); } -::concrete_optimizer::dag::OperatorIndex OperationDag::add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$add_round_op(*this, input, rounded_precision); +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_lut(*this, input, table, out_precision); } -::concrete_optimizer::dag::OperatorIndex OperationDag::add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$add_unsafe_cast_op(*this, input, rounded_precision); +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_dot(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_dot(*this, inputs, weights.into_raw()); } -::concrete_optimizer::dag::DagSolution OperationDag::optimize(::concrete_optimizer::Options options) const noexcept { - ::rust::MaybeUninit<::concrete_optimizer::dag::DagSolution> return$; - concrete_optimizer$cxxbridge1$OperationDag$optimize(*this, options, &return$.value); - return ::std::move(return$.value); +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_levelled_op(*this, inputs, lwe_dim_cost_factor, fixed_cost, manp, out_shape, comment); } -::rust::String OperationDag::dump() const noexcept { - ::rust::MaybeUninit<::rust::String> return$; - concrete_optimizer$cxxbridge1$OperationDag$dump(*this, &return$.value); - return ::std::move(return$.value); +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_round_op(*this, input, rounded_precision); } -void OperationDag::concat(::concrete_optimizer::OperationDag const &other) noexcept { - concrete_optimizer$cxxbridge1$OperationDag$concat(*this, other); +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_unsafe_cast_op(*this, input, rounded_precision); +} + +void DagBuilder::tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept { + concrete_optimizer$cxxbridge1$DagBuilder$tag_operator_as_output(*this, op); +} + +::concrete_optimizer::dag::DagSolution Dag::optimize(::concrete_optimizer::Options options) const noexcept { + ::rust::MaybeUninit<::concrete_optimizer::dag::DagSolution> return$; + concrete_optimizer$cxxbridge1$Dag$optimize(*this, options, &return$.value); + return ::std::move(return$.value); } namespace dag { @@ -1429,13 +1470,13 @@ ::rust::Box<::concrete_optimizer::Weights> number(::std::int64_t weight) noexcep } } // namespace weights -void OperationDag::tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept { - concrete_optimizer$cxxbridge1$OperationDag$tag_operator_as_output(*this, op); +::std::size_t Dag::get_circuit_count() const noexcept { + return concrete_optimizer$cxxbridge1$Dag$get_circuit_count(*this); } -::concrete_optimizer::dag::CircuitSolution OperationDag::optimize_multi(::concrete_optimizer::Options options) const noexcept { +::concrete_optimizer::dag::CircuitSolution Dag::optimize_multi(::concrete_optimizer::Options options) const noexcept { ::rust::MaybeUninit<::concrete_optimizer::dag::CircuitSolution> return$; - concrete_optimizer$cxxbridge1$OperationDag$optimize_multi(*this, options, &return$.value); + concrete_optimizer$cxxbridge1$Dag$optimize_multi(*this, options, &return$.value); return ::std::move(return$.value); } @@ -1445,9 +1486,13 @@ ::std::uint64_t NO_KEY_ID() noexcept { } // namespace concrete_optimizer extern "C" { -::concrete_optimizer::OperationDag *cxxbridge1$box$concrete_optimizer$OperationDag$alloc() noexcept; -void cxxbridge1$box$concrete_optimizer$OperationDag$dealloc(::concrete_optimizer::OperationDag *) noexcept; -void cxxbridge1$box$concrete_optimizer$OperationDag$drop(::rust::Box<::concrete_optimizer::OperationDag> *ptr) noexcept; +::concrete_optimizer::Dag *cxxbridge1$box$concrete_optimizer$Dag$alloc() noexcept; +void cxxbridge1$box$concrete_optimizer$Dag$dealloc(::concrete_optimizer::Dag *) noexcept; +void cxxbridge1$box$concrete_optimizer$Dag$drop(::rust::Box<::concrete_optimizer::Dag> *ptr) noexcept; + +::concrete_optimizer::DagBuilder *cxxbridge1$box$concrete_optimizer$DagBuilder$alloc() noexcept; +void cxxbridge1$box$concrete_optimizer$DagBuilder$dealloc(::concrete_optimizer::DagBuilder *) noexcept; +void cxxbridge1$box$concrete_optimizer$DagBuilder$drop(::rust::Box<::concrete_optimizer::DagBuilder> *ptr) noexcept; ::concrete_optimizer::Weights *cxxbridge1$box$concrete_optimizer$Weights$alloc() noexcept; void cxxbridge1$box$concrete_optimizer$Weights$dealloc(::concrete_optimizer::Weights *) noexcept; @@ -1520,16 +1565,28 @@ void cxxbridge1$rust_vec$concrete_optimizer$dag$InstructionKeys$truncate(::rust: namespace rust { inline namespace cxxbridge1 { template <> -::concrete_optimizer::OperationDag *Box<::concrete_optimizer::OperationDag>::allocation::alloc() noexcept { - return cxxbridge1$box$concrete_optimizer$OperationDag$alloc(); +::concrete_optimizer::Dag *Box<::concrete_optimizer::Dag>::allocation::alloc() noexcept { + return cxxbridge1$box$concrete_optimizer$Dag$alloc(); +} +template <> +void Box<::concrete_optimizer::Dag>::allocation::dealloc(::concrete_optimizer::Dag *ptr) noexcept { + cxxbridge1$box$concrete_optimizer$Dag$dealloc(ptr); +} +template <> +void Box<::concrete_optimizer::Dag>::drop() noexcept { + cxxbridge1$box$concrete_optimizer$Dag$drop(this); +} +template <> +::concrete_optimizer::DagBuilder *Box<::concrete_optimizer::DagBuilder>::allocation::alloc() noexcept { + return cxxbridge1$box$concrete_optimizer$DagBuilder$alloc(); } template <> -void Box<::concrete_optimizer::OperationDag>::allocation::dealloc(::concrete_optimizer::OperationDag *ptr) noexcept { - cxxbridge1$box$concrete_optimizer$OperationDag$dealloc(ptr); +void Box<::concrete_optimizer::DagBuilder>::allocation::dealloc(::concrete_optimizer::DagBuilder *ptr) noexcept { + cxxbridge1$box$concrete_optimizer$DagBuilder$dealloc(ptr); } template <> -void Box<::concrete_optimizer::OperationDag>::drop() noexcept { - cxxbridge1$box$concrete_optimizer$OperationDag$drop(this); +void Box<::concrete_optimizer::DagBuilder>::drop() noexcept { + cxxbridge1$box$concrete_optimizer$DagBuilder$drop(this); } template <> ::concrete_optimizer::Weights *Box<::concrete_optimizer::Weights>::allocation::alloc() noexcept { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 4636598f50..b254a84586 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -922,7 +922,8 @@ std::size_t align_of() { struct PrivateFunctionalPackingBoostrapKey; struct CircuitKeys; namespace concrete_optimizer { - struct OperationDag; + struct Dag; + struct DagBuilder; struct Weights; enum class Encoding : ::std::uint8_t; enum class MultiParamStrategy : ::std::uint8_t; @@ -946,21 +947,37 @@ namespace concrete_optimizer { } namespace concrete_optimizer { -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag -#define CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag -struct OperationDag final : public ::rust::Opaque { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Dag +#define CXXBRIDGE1_STRUCT_concrete_optimizer$Dag +struct Dag final : public ::rust::Opaque { + ::rust::Box<::concrete_optimizer::DagBuilder> builder(::rust::String circuit) noexcept; + ::rust::String dump() const noexcept; + ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; + ::std::size_t get_circuit_count() const noexcept; + ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; + ~Dag() = delete; + +private: + friend ::rust::layout; + struct layout { + static ::std::size_t size() noexcept; + static ::std::size_t align() noexcept; + }; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Dag + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$DagBuilder +#define CXXBRIDGE1_STRUCT_concrete_optimizer$DagBuilder +struct DagBuilder final : public ::rust::Opaque { + ::rust::String dump() const noexcept; ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape) noexcept; ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept; ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; ::concrete_optimizer::dag::OperatorIndex add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; - ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; - ::rust::String dump() const noexcept; - void concat(::concrete_optimizer::OperationDag const &other) noexcept; void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept; - ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; - ~OperationDag() = delete; + ~DagBuilder() = delete; private: friend ::rust::layout; @@ -969,7 +986,7 @@ struct OperationDag final : public ::rust::Opaque { static ::std::size_t align() noexcept; }; }; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$DagBuilder #ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Weights #define CXXBRIDGE1_STRUCT_concrete_optimizer$Weights @@ -1237,11 +1254,11 @@ ::concrete_optimizer::v0::Solution optimize_bootstrap(::std::uint64_t precision, namespace utils { ::concrete_optimizer::dag::DagSolution convert_to_dag_solution(::concrete_optimizer::v0::Solution const &solution) noexcept; -::concrete_optimizer::dag::CircuitSolution convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag) noexcept; +::concrete_optimizer::dag::CircuitSolution convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::Dag const &dag) noexcept; } // namespace utils namespace dag { -::rust::Box<::concrete_optimizer::OperationDag> empty() noexcept; +::rust::Box<::concrete_optimizer::Dag> empty() noexcept; } // namespace dag namespace weights { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp index 1441e09ebd..5766b6ff03 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp @@ -48,11 +48,12 @@ TEST test_v0() { TEST test_dag_no_lut() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {3}; concrete_optimizer::dag::OperatorIndex node1 = - dag->add_input(PRECISION_8B, slice(shape)); + builder->add_input(PRECISION_8B, slice(shape)); std::vector inputs = {node1}; @@ -61,8 +62,8 @@ TEST test_dag_no_lut() { rust::cxxbridge1::Box weights = concrete_optimizer::weights::vector(slice(weight_vec)); - auto id = dag->add_dot(slice(inputs), std::move(weights)); - dag->tag_operator_as_output(id); + auto id = builder->add_dot(slice(inputs), std::move(weights)); + builder->tag_operator_as_output(id); auto solution = dag->optimize(default_options()); assert(solution.glwe_polynomial_size == 1); @@ -71,15 +72,16 @@ TEST test_dag_no_lut() { TEST test_dag_lut() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {3}; concrete_optimizer::dag::OperatorIndex input = - dag->add_input(PRECISION_8B, slice(shape)); + builder->add_input(PRECISION_8B, slice(shape)); std::vector table = {}; - auto id = dag->add_lut(input, slice(table), PRECISION_8B); - dag->tag_operator_as_output(id); + auto id = builder->add_lut(input, slice(table), PRECISION_8B); + builder->tag_operator_as_output(id); auto solution = dag->optimize(default_options()); assert(solution.glwe_dimension == 1); @@ -89,15 +91,16 @@ TEST test_dag_lut() { TEST test_dag_lut_wop() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {3}; concrete_optimizer::dag::OperatorIndex input = - dag->add_input(PRECISION_16B, slice(shape)); + builder->add_input(PRECISION_16B, slice(shape)); std::vector table = {}; - auto id = dag->add_lut(input, slice(table), PRECISION_16B); - dag->tag_operator_as_output(id); + auto id = builder->add_lut(input, slice(table), PRECISION_16B); + builder->tag_operator_as_output(id); auto solution = dag->optimize(default_options()); assert(solution.glwe_dimension == 2); @@ -107,15 +110,16 @@ TEST test_dag_lut_wop() { TEST test_dag_lut_force_wop() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {3}; concrete_optimizer::dag::OperatorIndex input = - dag->add_input(PRECISION_8B, slice(shape)); + builder->add_input(PRECISION_8B, slice(shape)); std::vector table = {}; - auto id = dag->add_lut(input, slice(table), PRECISION_8B); - dag->tag_operator_as_output(id); + auto id = builder->add_lut(input, slice(table), PRECISION_8B); + builder->tag_operator_as_output(id); auto options = default_options(); options.encoding = concrete_optimizer::Encoding::Crt; @@ -126,15 +130,16 @@ TEST test_dag_lut_force_wop() { TEST test_multi_parameters_1_precision() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {3}; concrete_optimizer::dag::OperatorIndex input = - dag->add_input(PRECISION_8B, slice(shape)); + builder->add_input(PRECISION_8B, slice(shape)); std::vector table = {}; - auto id = dag->add_lut(input, slice(table), PRECISION_8B); - dag->tag_operator_as_output(id); + auto id = builder->add_lut(input, slice(table), PRECISION_8B); + builder->tag_operator_as_output(id); auto options = default_options(); auto circuit_solution = dag->optimize_multi(options); @@ -152,18 +157,19 @@ TEST test_multi_parameters_1_precision() { TEST test_multi_parameters_2_precision() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {3}; concrete_optimizer::dag::OperatorIndex input1 = - dag->add_input(PRECISION_8B, slice(shape)); + builder->add_input(PRECISION_8B, slice(shape)); concrete_optimizer::dag::OperatorIndex input2 = - dag->add_input(PRECISION_1B, slice(shape)); + builder->add_input(PRECISION_1B, slice(shape)); std::vector table = {}; - auto lut1 = dag->add_lut(input1, slice(table), PRECISION_8B); - auto lut2 = dag->add_lut(input2, slice(table), PRECISION_8B); + auto lut1 = builder->add_lut(input1, slice(table), PRECISION_8B); + auto lut2 = builder->add_lut(input2, slice(table), PRECISION_8B); std::vector inputs = {lut1, lut2}; @@ -172,8 +178,8 @@ TEST test_multi_parameters_2_precision() { rust::cxxbridge1::Box weights = concrete_optimizer::weights::vector(slice(weight_vec)); - auto id = dag->add_dot(slice(inputs), std::move(weights)); - dag->tag_operator_as_output(id); + auto id = builder->add_dot(slice(inputs), std::move(weights)); + builder->tag_operator_as_output(id); auto options = default_options(); auto circuit_solution = dag->optimize_multi(options); @@ -192,18 +198,19 @@ TEST test_multi_parameters_2_precision() { TEST test_multi_parameters_2_precision_crt() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {3}; concrete_optimizer::dag::OperatorIndex input1 = - dag->add_input(PRECISION_8B, slice(shape)); + builder->add_input(PRECISION_8B, slice(shape)); concrete_optimizer::dag::OperatorIndex input2 = - dag->add_input(PRECISION_1B, slice(shape)); + builder->add_input(PRECISION_1B, slice(shape)); std::vector table = {}; - auto lut1 = dag->add_lut(input1, slice(table), PRECISION_8B); - auto lut2 = dag->add_lut(input2, slice(table), PRECISION_8B); + auto lut1 = builder->add_lut(input1, slice(table), PRECISION_8B); + auto lut2 = builder->add_lut(input2, slice(table), PRECISION_8B); std::vector inputs = {lut1, lut2}; @@ -212,8 +219,8 @@ TEST test_multi_parameters_2_precision_crt() { rust::cxxbridge1::Box weights = concrete_optimizer::weights::vector(slice(weight_vec)); - auto id = dag->add_dot(slice(inputs), std::move(weights)); - dag->tag_operator_as_output(id); + auto id = builder->add_dot(slice(inputs), std::move(weights)); + builder->tag_operator_as_output(id); auto options = default_options(); options.encoding = concrete_optimizer::Encoding::Crt; @@ -228,25 +235,26 @@ TEST test_multi_parameters_2_precision_crt() { TEST test_composable_dag_mono_fallback_on_dag_multi() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {}; concrete_optimizer::dag::OperatorIndex input1 = - dag->add_input(PRECISION_8B, slice(shape)); + builder->add_input(PRECISION_8B, slice(shape)); std::vector inputs = {input1}; std::vector weight_vec = {1 << 8}; rust::cxxbridge1::Box weights1 = concrete_optimizer::weights::vector(slice(weight_vec)); - input1 = dag->add_dot(slice(inputs), std::move(weights1)); + input1 = builder->add_dot(slice(inputs), std::move(weights1)); std::vector table = {}; - auto lut1 = dag->add_lut(input1, slice(table), PRECISION_8B); + auto lut1 = builder->add_lut(input1, slice(table), PRECISION_8B); std::vector lut1v = {lut1}; rust::cxxbridge1::Box weights2 = concrete_optimizer::weights::vector(slice(weight_vec)); - auto id = dag->add_dot(slice(lut1v), std::move(weights2)); - dag->tag_operator_as_output(id); + auto id = builder->add_dot(slice(lut1v), std::move(weights2)); + builder->tag_operator_as_output(id); auto options = default_options(); auto solution1 = dag->optimize(options); @@ -262,11 +270,12 @@ TEST test_composable_dag_mono_fallback_on_dag_multi() { TEST test_non_composable_dag_mono_fallback_on_woppbs() { auto dag = concrete_optimizer::dag::empty(); + auto builder = dag->builder("test"); std::vector shape = {}; concrete_optimizer::dag::OperatorIndex input1 = - dag->add_input(PRECISION_8B, slice(shape)); + builder->add_input(PRECISION_8B, slice(shape)); std::vector inputs = {input1}; @@ -274,14 +283,14 @@ TEST test_non_composable_dag_mono_fallback_on_woppbs() { rust::cxxbridge1::Box weights1 = concrete_optimizer::weights::vector(slice(weight_vec)); - input1 = dag->add_dot(slice(inputs), std::move(weights1)); + input1 = builder->add_dot(slice(inputs), std::move(weights1)); std::vector table = {}; - auto lut1 = dag->add_lut(input1, slice(table), PRECISION_8B); + auto lut1 = builder->add_lut(input1, slice(table), PRECISION_8B); std::vector lut1v = {lut1}; rust::cxxbridge1::Box weights2 = concrete_optimizer::weights::vector(slice(weight_vec)); - auto id = dag->add_dot(slice(lut1v), std::move(weights2)); - dag->tag_operator_as_output(id); + auto id = builder->add_dot(slice(lut1v), std::move(weights2)); + builder->tag_operator_as_output(id); auto options = default_options(); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs index c52fcbf1dd..5561d005a8 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs @@ -1,5 +1,6 @@ use std::fmt; use std::iter::{empty, once}; +use std::ops::Deref; use crate::dag::operator::tensor::{ClearTensor, Shape}; @@ -116,8 +117,20 @@ impl Operator { } #[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub struct OperatorIndex { - pub i: usize, +pub struct OperatorIndex(pub usize); + +impl Deref for OperatorIndex { + type Target = usize; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for OperatorIndex { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } } impl fmt::Display for Operator { @@ -135,21 +148,21 @@ impl fmt::Display for Operator { if i > 0 { write!(f, " + ")?; } - write!(f, "{weight} x %{}", input.i)?; + write!(f, "{weight} x %{}", input.0)?; } } Self::UnsafeCast { input, out_precision, } => { - write!(f, "%{} : u{out_precision}", input.i)?; + write!(f, "%{} : u{out_precision}", input.0)?; } Self::Lut { input, out_precision, .. } => { - write!(f, "LUT[%{}] : u{out_precision}", input.i)?; + write!(f, "LUT[%{}] : u{out_precision}", input.0)?; } Self::LevelledOp { inputs, @@ -162,7 +175,7 @@ impl fmt::Display for Operator { if i > 0 { write!(f, ", ")?; } - write!(f, "%{}", input.i)?; + write!(f, "%{}", input.0)?; } write!(f, "] : manp={manp} x {out_shape:?}")?; } @@ -170,7 +183,7 @@ impl fmt::Display for Operator { input, out_precision, } => { - write!(f, "ROUND[%{}] : u{out_precision}", input.i)?; + write!(f, "ROUND[%{}] : u{out_precision}", input.0)?; } } Ok(()) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs index b9ea66552f..94c9ce3e58 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs @@ -1,6 +1,6 @@ use crate::dag::operator::operator::Operator; use crate::dag::operator::OperatorIndex; -use crate::dag::unparametrized::OperationDag; +use crate::dag::unparametrized::Dag; fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator { let mut op = op.clone(); @@ -8,10 +8,10 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator { Operator::Input { .. } => (), Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } - | Operator::Round { input, .. } => input.i = old_index_to_new[input.i], + | Operator::Round { input, .. } => input.0 = old_index_to_new[input.0], Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { for input in inputs { - input.i = old_index_to_new[input.i]; + input.0 = old_index_to_new[input.0]; } } }; @@ -19,16 +19,16 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator { } pub(crate) fn regen( - dag: &OperationDag, - f: &mut dyn FnMut(usize, &Operator, &mut OperationDag) -> Option, -) -> (OperationDag, Vec>) { - let mut regen_dag = OperationDag::new(); + dag: &Dag, + f: &mut dyn FnMut(usize, &Operator, &mut Dag) -> Option, +) -> (Dag, Vec>) { + let mut regen_dag = Dag::new(); let mut old_index_to_new = vec![]; for (i, op) in dag.operators.iter().enumerate() { let op = reindex_op_inputs(op, &old_index_to_new); let size = regen_dag.operators.len(); if let Some(op_i) = f(i, &op, &mut regen_dag) { - old_index_to_new.push(op_i.i); + old_index_to_new.push(op_i.0); } else { assert!(size == regen_dag.operators.len()); old_index_to_new.push(regen_dag.len()); @@ -36,6 +36,7 @@ pub(crate) fn regen( regen_dag.out_precisions.push(dag.out_precisions[i]); regen_dag.out_shapes.push(dag.out_shapes[i].clone()); regen_dag.output_tags.push(dag.output_tags[i]); + regen_dag.circuit_tags.push(dag.circuit_tags[i].clone()); } } (regen_dag, instructions_multi_map(&old_index_to_new)) @@ -48,11 +49,7 @@ fn instructions_multi_map(old_index_to_new: &[usize]) -> Vec> for &new_instr in old_index_to_new { let start_from = last_new_instr.map_or(new_instr, |v: usize| v + 1); if start_from <= new_instr { - result.push( - (start_from..=new_instr) - .map(|i| OperatorIndex { i }) - .collect(), - ); + result.push((start_from..=new_instr).map(OperatorIndex).collect()); } else { result.push(vec![]); } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs index 04487ecef5..c243a33fa1 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs @@ -1,9 +1,9 @@ use crate::dag::operator::{Operator, OperatorIndex}; -use crate::dag::unparametrized::OperationDag; +use crate::dag::unparametrized::Dag; use super::regen::regen; -fn regen_round(_: usize, op: &Operator, dag: &mut OperationDag) -> Option { +fn regen_round(_: usize, op: &Operator, dag: &mut Dag) -> Option { match *op { Operator::Round { input, @@ -13,12 +13,10 @@ fn regen_round(_: usize, op: &Operator, dag: &mut OperationDag) -> Option OperationDag { +pub(crate) fn expand_round(dag: &Dag) -> Dag { regen(dag, &mut regen_round).0 } -pub(crate) fn expand_round_and_index_map( - dag: &OperationDag, -) -> (OperationDag, Vec>) { +pub(crate) fn expand_round_and_index_map(dag: &Dag) -> (Dag, Vec>) { regen(dag, &mut regen_round) } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index e235769d10..6c2f3c7e69 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -1,52 +1,106 @@ -use std::fmt; -use std::fmt::Write; - use crate::dag::operator::{ dot_kind, DotKind, FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights, }; +use std::{collections::HashSet, fmt}; + +/// The name of the default. Used when adding operations directly on the dag instead of via a builder. +const DEFAULT_CIRCUIT: &str = "_"; + +/// A type referencing every informations related to an operator of the dag. +#[derive(Debug, Clone)] +#[allow(unused)] +pub(crate) struct DagOperator<'dag> { + pub(crate) id: OperatorIndex, + pub(crate) dag: &'dag Dag, + pub(crate) operator: &'dag Operator, + pub(crate) shape: &'dag Shape, + pub(crate) precision: &'dag Precision, + pub(crate) output_tag: &'dag bool, + pub(crate) circuit_tag: &'dag String, +} -pub(crate) type UnparameterizedOperator = Operator; +impl<'dag> DagOperator<'dag> { + /// Returns if the operator is an input. + pub(crate) fn is_input(&self) -> bool { + matches!(self.operator, Operator::Input { .. }) + } -#[derive(Clone, PartialEq, Debug)] -#[must_use] -pub struct OperationDag { - pub(crate) operators: Vec, - // Collect all operators output shape - pub(crate) out_shapes: Vec, - // Collect all operators output precision - pub(crate) out_precisions: Vec, - // Collect whether operators are tagged as outputs - pub(crate) output_tags: Vec, + /// Returns if the operator is an output. + pub(crate) fn is_output(&self) -> bool { + *self.output_tag + } + + /// Returns an iterator over the operators used as input to this operator. + pub(crate) fn get_inputs_iter(&self) -> impl Iterator> + '_ { + self.operator + .get_inputs_iter() + .map(|id| self.dag.get_operator(*id)) + } } -impl fmt::Display for OperationDag { +/// A structure referencing the operators associated to a particular circuit. +#[derive(Debug, Clone)] +pub struct DagCircuit<'dag> { + pub(crate) dag: &'dag Dag, + pub(crate) ids: Vec, + pub(crate) circuit: String, +} + +impl<'dag> DagCircuit<'dag> { + /// Returns an iterator over the operators of this circuit. + pub(crate) fn get_operators_iter(&self) -> impl Iterator> + '_ { + self.ids.iter().map(|id| self.dag.get_operator(*id)) + } + + /// Returns an iterator over the circuit's input operators. + #[allow(unused)] + pub(crate) fn get_input_operators_iter(&self) -> impl Iterator> + '_ { + self.get_operators_iter().filter(DagOperator::is_input) + } + + /// Returns an iterator over the circuit's output operators. + #[allow(unused)] + pub(crate) fn get_output_operators_iter(&self) -> impl Iterator> + '_ { + self.get_operators_iter().filter(DagOperator::is_output) + } +} + +impl<'dag> fmt::Display for DagCircuit<'dag> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - for (i, op) in self.operators.iter().enumerate() { - writeln!(f, "%{i} <- {op}")?; + for op in self.get_operators_iter() { + writeln!(f, "{} <- {:?}", op.id, op.operator)?; } Ok(()) } } -impl OperationDag { - pub const fn new() -> Self { - Self { - operators: vec![], - out_shapes: vec![], - out_precisions: vec![], - output_tags: vec![], - } - } +/// A type allowing build a circuit in a dag. +/// +/// See [Dag] for more informations on dag building. +#[derive(Debug)] +pub struct DagBuilder<'dag> { + dag: &'dag mut Dag, + pub(crate) circuit: String, +} - fn add_operator(&mut self, operator: UnparameterizedOperator) -> OperatorIndex { - let i = self.operators.len(); - self.out_precisions +impl<'dag> DagBuilder<'dag> { + fn add_operator(&mut self, operator: Operator) -> OperatorIndex { + debug_assert!(operator + .get_inputs_iter() + .all(|id| self.dag.circuit_tags[id.0] == self.circuit)); + let i = self.dag.operators.len(); + self.dag + .out_precisions .push(self.infer_out_precision(&operator)); - self.out_shapes.push(self.infer_out_shape(&operator)); - self.operators.push(operator); - self.output_tags.push(false); - OperatorIndex { i } + self.dag.out_shapes.push(self.infer_out_shape(&operator)); + operator + .get_inputs_iter() + .for_each(|id| self.dag.output_tags[id.0] = false); + self.dag.operators.push(operator); + self.dag.output_tags.push(true); + self.dag.circuit_tags.push(self.circuit.clone()); + OperatorIndex(i) } pub fn add_input( @@ -110,7 +164,7 @@ impl OperationDag { input: OperatorIndex, out_precision: Precision, ) -> OperatorIndex { - let input_precision = self.out_precisions[input.i]; + let input_precision = self.dag.out_precisions[input.0]; if input_precision == out_precision { return input; } @@ -125,7 +179,7 @@ impl OperationDag { input: OperatorIndex, rounded_precision: Precision, ) -> OperatorIndex { - let in_precision = self.out_precisions[input.i]; + let in_precision = self.dag.out_precisions[input.0]; assert!(rounded_precision <= in_precision); self.add_operator(Operator::Round { input, @@ -133,30 +187,10 @@ impl OperationDag { }) } - pub fn tag_operator_as_output(&mut self, operator: OperatorIndex) { - assert!(operator.i < self.len()); - self.output_tags[operator.i] = true; - } - - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> usize { - self.operators.len() - } - - pub fn dump(&self) -> String { - let mut acc = String::new(); - let err_msg = "Optimizer: Can't dump OperationDag"; - writeln!(acc, "Dag:").expect(err_msg); - for (i, op) in self.operators.iter().enumerate() { - writeln!(acc, "%{i} <- {op:?}").expect(err_msg); - } - acc - } - fn add_shift_left_lsb_to_msb_no_padding(&mut self, input: OperatorIndex) -> OperatorIndex { // Convert any input to simple 1bit msb replacing the padding // For now encoding is not explicit, so 1 bit content without padding <=> 0 bit content with padding. - let in_precision = self.out_precisions[input.i]; + let in_precision = self.dag.out_precisions[input.0]; let shift_factor = Weights::number(1 << (in_precision as i64)); let lsb_as_msb = self.add_dot([input], shift_factor); self.add_unsafe_cast(lsb_as_msb, 0 as Precision) @@ -169,7 +203,7 @@ impl OperationDag { out_precision: Precision, ) -> OperatorIndex { // For now encoding is not explicit, so 1 bit content without padding <=> 0 bit content with padding. - let in_precision = self.out_precisions[input.i]; + let in_precision = self.dag.out_precisions[input.0]; assert!(in_precision == 0); // An add after with a clear constant is skipped here as it doesn't change noise handling. self.add_lut(input, table, out_precision) @@ -189,7 +223,7 @@ impl OperationDag { // The lowest bit is converted to a ciphertext of same precision as input. // Introduce a pbs of input precision but this precision is only used on 1 levelled op and converted to lower precision // Noise is reduced by a pbs. - let out_precision = self.out_precisions[input.i]; + let out_precision = self.dag.out_precisions[input.0]; let lsb_as_msb = self.add_shift_left_lsb_to_msb_no_padding(input); self.add_shift_right_msb_no_padding_to_lsb(lsb_as_msb, out_precision) } @@ -197,7 +231,7 @@ impl OperationDag { pub fn add_truncate_1_bit(&mut self, input: OperatorIndex) -> OperatorIndex { // Reset a bit. // ex: 10110 is truncated to 1011, 10111 is truncated to 1011 - let in_precision = self.out_precisions[input.i]; + let in_precision = self.dag.out_precisions[input.0]; let lowest_bit = self.add_isolate_lowest_bit(input); let bit_cleared = self.add_dot([input, lowest_bit], [1, -1]); self.add_unsafe_cast(bit_cleared, in_precision - 1) @@ -214,7 +248,7 @@ impl OperationDag { // Note: this is a simplified graph, some constant additions are missing without consequence on crypto parameter choice. // Note: reset and rounds could be done by 4, 3, 2 and 1 bits groups for efficiency. // bit efficiency is better for 4 precision then 3, but the feasibility is lower for high noise - let in_precision = self.out_precisions[input.i]; + let in_precision = self.dag.out_precisions[input.0]; assert!(rounded_precision <= in_precision); if in_precision == rounded_precision { return input; @@ -252,84 +286,35 @@ impl OperationDag { self.add_lut(rounded, table, out_precision) } - /// Concatenates two dags into a single one (with two disconnected clusters). - pub fn concat(&mut self, other: &Self) { - let length = self.len(); - self.operators.extend(other.operators.iter().cloned()); - self.out_precisions.extend(other.out_precisions.iter()); - self.out_shapes.extend(other.out_shapes.iter().cloned()); - self.output_tags.extend(other.output_tags.iter()); - self.operators[length..] - .iter_mut() - .for_each(|node| match node { - Operator::Lut { ref mut input, .. } - | Operator::UnsafeCast { ref mut input, .. } - | Operator::Round { ref mut input, .. } => { - input.i += length; - } - Operator::Dot { ref mut inputs, .. } - | Operator::LevelledOp { ref mut inputs, .. } => { - inputs.iter_mut().for_each(|inp| inp.i += length); - } - _ => (), - }); - } - - /// Returns an iterator over input nodes indices. - pub(crate) fn get_input_index_iter(&self) -> impl Iterator + '_ { - self.operators - .iter() - .enumerate() - .filter_map(|(index, op)| match op { - Operator::Input { .. } => Some(index), - _ => None, - }) - } - - /// If no outputs were declared, automatically tag final nodes as outputs. - #[allow(unused)] - pub(crate) fn detect_outputs(&mut self) { - assert!(!self.is_output_tagged()); - self.output_tags = vec![true; self.len()]; - self.operators - .iter() - .flat_map(|op| op.get_inputs_iter()) - .for_each(|op| self.output_tags[op.i] = false); - } - - fn is_output_tagged(&self) -> bool { - self.output_tags - .iter() - .copied() - .reduce(|a, b| a || b) - .unwrap() - } - - /// Returns an iterator over output nodes indices. - pub(crate) fn get_output_index_iter(&self) -> impl Iterator + '_ { - self.output_tags - .iter() - .enumerate() - .filter_map(|(index, is_output)| is_output.then_some(index)) + /// Marks an operator as being an output of the circuit. + /// + /// # Note: + /// + /// Operators without consumers are automatically tagged as output. For operators that are used + /// as input to another operator, but at the same time are returned by the circuit, they must be + /// tagged using this method. + pub fn tag_operator_as_output(&mut self, operator: OperatorIndex) { + assert!(operator.0 < self.dag.len()); + debug_assert!(self.dag.circuit_tags[operator.0] == self.circuit); + self.dag.output_tags[operator.0] = true; } - /// Returns whether the node is tagged as output. - pub(crate) fn is_output_node(&self, oid: usize) -> bool { - self.output_tags[oid] + pub fn get_circuit(&self) -> DagCircuit<'_> { + self.dag.get_circuit(&self.circuit) } - fn infer_out_shape(&self, op: &UnparameterizedOperator) -> Shape { + fn infer_out_shape(&self, op: &Operator) -> Shape { match op { Operator::Input { out_shape, .. } | Operator::LevelledOp { out_shape, .. } => { out_shape.clone() } Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } - | Operator::Round { input, .. } => self.out_shapes[input.i].clone(), + | Operator::Round { input, .. } => self.dag.out_shapes[input.0].clone(), Operator::Dot { inputs, weights, .. } => { - let input_shape = self.out_shapes[inputs[0].i].clone(); + let input_shape = self.dag.out_shapes[inputs[0].0].clone(); let kind = dot_kind(inputs.len() as u64, &input_shape, weights); match kind { DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => { @@ -353,74 +338,355 @@ impl OperationDag { } } - fn infer_out_precision(&self, op: &UnparameterizedOperator) -> Precision { + fn infer_out_precision(&self, op: &Operator) -> Precision { match op { Operator::Input { out_precision, .. } | Operator::Lut { out_precision, .. } | Operator::UnsafeCast { out_precision, .. } | Operator::Round { out_precision, .. } => *out_precision, Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { - self.out_precisions[inputs[0].i] + self.dag.out_precisions[inputs[0].0] } } } } +/// A type containing a Directed Acyclic Graph of operators. +/// +/// This is the major datatype used to encode a module in the optimizer. It is equivalent to an fhe +/// module in the frontend, or to a mlir module in the compiler: it can contain multiple separated +/// circuits which are optimized together. +/// +/// To add a new circuit to the dag, one should create a [`DagBuilder`] associated to the circuit +/// using the [`Dag::builder`] method, and use the `DagBuilder::add_*` methods. +/// +/// # Note: +/// +/// For ease of use in tests, it is also possible to add operators on an anonymous circuit (`_`) +/// directly on a [`Dag`] object itself, using the `Dag::add_*` methods. +#[derive(Clone, PartialEq, Debug)] +#[must_use] +pub struct Dag { + pub(crate) operators: Vec, + // Collect all operators output shape + pub(crate) out_shapes: Vec, + // Collect all operators output precision + pub(crate) out_precisions: Vec, + // Collect whether operators are tagged as outputs + pub(crate) output_tags: Vec, + // Collect the circuit the operators are associated with + pub(crate) circuit_tags: Vec, +} + +impl fmt::Display for Dag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for op in self.get_operators_iter() { + writeln!(f, "{} <- {:?}", op.id, op.operator)?; + } + Ok(()) + } +} + +impl Default for Dag { + fn default() -> Self { + Self::new() + } +} + +impl Dag { + /// Instantiate a new dag. + pub fn new() -> Self { + Self { + operators: vec![], + out_shapes: vec![], + out_precisions: vec![], + output_tags: vec![], + circuit_tags: vec![], + } + } + + /// Returns a builder for the circuit named `circuit`. + pub fn builder>(&mut self, circuit: A) -> DagBuilder { + DagBuilder { + dag: self, + circuit: circuit.as_ref().into(), + } + } + + pub fn add_input( + &mut self, + out_precision: Precision, + out_shape: impl Into, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_input(out_precision, out_shape) + } + + pub fn add_lut( + &mut self, + input: OperatorIndex, + table: FunctionTable, + out_precision: Precision, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_lut(input, table, out_precision) + } + + pub fn add_dot( + &mut self, + inputs: impl Into>, + weights: impl Into, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT).add_dot(inputs, weights) + } + + pub fn add_levelled_op( + &mut self, + inputs: impl Into>, + complexity: LevelledComplexity, + manp: f64, + out_shape: impl Into, + comment: impl Into, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_levelled_op(inputs, complexity, manp, out_shape, comment) + } + + pub fn add_unsafe_cast( + &mut self, + input: OperatorIndex, + out_precision: Precision, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_unsafe_cast(input, out_precision) + } + + pub fn add_round_op( + &mut self, + input: OperatorIndex, + rounded_precision: Precision, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_round_op(input, rounded_precision) + } + + #[allow(unused)] + fn add_shift_left_lsb_to_msb_no_padding(&mut self, input: OperatorIndex) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_shift_left_lsb_to_msb_no_padding(input) + } + + #[allow(unused)] + fn add_lut_1bit_no_padding( + &mut self, + input: OperatorIndex, + table: FunctionTable, + out_precision: Precision, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_lut_1bit_no_padding(input, table, out_precision) + } + + #[allow(unused)] + fn add_shift_right_msb_no_padding_to_lsb( + &mut self, + input: OperatorIndex, + out_precision: Precision, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_shift_right_msb_no_padding_to_lsb(input, out_precision) + } + + #[allow(unused)] + fn add_isolate_lowest_bit(&mut self, input: OperatorIndex) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT).add_isolate_lowest_bit(input) + } + + pub fn add_truncate_1_bit(&mut self, input: OperatorIndex) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT).add_truncate_1_bit(input) + } + + pub fn add_expanded_round( + &mut self, + input: OperatorIndex, + rounded_precision: Precision, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT) + .add_expanded_round(input, rounded_precision) + } + + pub fn add_expanded_rounded_lut( + &mut self, + input: OperatorIndex, + table: FunctionTable, + rounded_precision: Precision, + out_precision: Precision, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT).add_expanded_rounded_lut( + input, + table, + rounded_precision, + out_precision, + ) + } + + pub fn add_rounded_lut( + &mut self, + input: OperatorIndex, + table: FunctionTable, + rounded_precision: Precision, + out_precision: Precision, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT).add_expanded_rounded_lut( + input, + table, + rounded_precision, + out_precision, + ) + } + + /// Returns an iterator over the operator indices. + pub(crate) fn get_indices_iter(&self) -> impl Iterator { + (0..self.len()).map(OperatorIndex) + } + + /// Returns an iterator over the circuits contained in the dag. + pub(crate) fn get_circuits_iter(&self) -> impl Iterator> + '_ { + let mut circuits: HashSet = HashSet::new(); + self.circuit_tags.iter().for_each(|name| { + let _ = circuits.insert(name.to_owned()); + }); + circuits + .into_iter() + .map(|circuit| self.get_circuit(circuit)) + } + + /// Returns a circuit object from its name. + /// + /// # Note: + /// + /// Panics if no circuit with the given name exist in the dag. + pub(crate) fn get_circuit>(&self, circuit: A) -> DagCircuit { + let circuit = circuit.as_ref().to_string(); + assert!(self.circuit_tags.contains(&circuit)); + let ids = self + .circuit_tags + .iter() + .enumerate() + .filter_map(|(id, circ)| (*circ == circuit).then_some(OperatorIndex(id))) + .collect(); + DagCircuit { + dag: self, + circuit, + ids, + } + } + + /// Returns an iterator over the input operators of the dag. + pub(crate) fn get_input_operators_iter(&self) -> impl Iterator> { + self.get_indices_iter() + .map(|i| self.get_operator(i)) + .filter(DagOperator::is_input) + } + + /// Returns an iterator over the outputs operators of the dag. + pub(crate) fn get_output_operators_iter(&self) -> impl Iterator> { + self.get_indices_iter() + .map(|i| self.get_operator(i)) + .filter(DagOperator::is_output) + } + + /// Returns an iterator over the operators of the dag. + pub(crate) fn get_operators_iter(&self) -> impl Iterator> { + self.get_indices_iter().map(|i| self.get_operator(i)) + } + + /// Returns an operator from its operator index. + /// + /// # Note: + /// + /// Panics if the operator index is invalid. + pub(crate) fn get_operator(&self, id: OperatorIndex) -> DagOperator<'_> { + assert!(id.0 < self.len()); + DagOperator { + dag: self, + id, + operator: self.operators.get(id.0).unwrap(), + shape: self.out_shapes.get(id.0).unwrap(), + precision: self.out_precisions.get(id.0).unwrap(), + output_tag: self.output_tags.get(id.0).unwrap(), + circuit_tag: self.circuit_tags.get(id.0).unwrap(), + } + } + + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.operators.len() + } + + /// Marks an operator as being an output of a circuit. + /// + /// # Note: + /// + /// Operators without consumers are automatically tagged as output. For operators that are used + /// as input to another operator, but at the same time are returned by the circuit, they must be + /// tagged using this method. + pub fn tag_operator_as_output(&mut self, operator: OperatorIndex) { + assert!(operator.0 < self.len()); + self.output_tags[operator.0] = true; + } + + /// Returns the number of circuits in the dag. + pub fn get_circuit_count(&self) -> usize { + self.get_circuits_iter().count() + } +} + #[cfg(test)] mod tests { use super::*; use crate::dag::operator::Shape; #[test] - fn graph_concat() { - let mut graph1 = OperationDag::new(); - let a = graph1.add_input(1, Shape::number()); - let b = graph1.add_input(1, Shape::number()); - let c = graph1.add_dot([a, b], [1, 1]); - let _d = graph1.add_lut(c, FunctionTable::UNKWOWN, 1); - let mut graph2 = OperationDag::new(); - let a = graph2.add_input(2, Shape::number()); - let b = graph2.add_input(2, Shape::number()); - let c = graph2.add_dot([a, b], [2, 2]); - let _d = graph2.add_lut(c, FunctionTable::UNKWOWN, 2); - graph1.concat(&graph2); - - let mut graph3 = OperationDag::new(); - let a = graph3.add_input(1, Shape::number()); - let b = graph3.add_input(1, Shape::number()); - let c = graph3.add_dot([a, b], [1, 1]); - let _d = graph3.add_lut(c, FunctionTable::UNKWOWN, 1); - let a = graph3.add_input(2, Shape::number()); - let b = graph3.add_input(2, Shape::number()); - let c = graph3.add_dot([a, b], [2, 2]); - let _d = graph3.add_lut(c, FunctionTable::UNKWOWN, 2); - - assert_eq!(graph1, graph3); + #[allow(clippy::many_single_char_names)] + fn graph_builder() { + let mut graph = Dag::new(); + let mut builder = graph.builder("main1"); + let a = builder.add_input(1, Shape::number()); + let b = builder.add_input(1, Shape::number()); + let c = builder.add_dot([a, b], [1, 1]); + let _d = builder.add_lut(c, FunctionTable::UNKWOWN, 1); + let mut builder = graph.builder("main2"); + let e = builder.add_input(2, Shape::number()); + let f = builder.add_input(2, Shape::number()); + let g = builder.add_dot([e, f], [2, 2]); + let _h = builder.add_lut(g, FunctionTable::UNKWOWN, 2); + graph.tag_operator_as_output(c); } #[test] fn graph_creation() { - let mut graph = OperationDag::new(); - - let input1 = graph.add_input(1, Shape::number()); + let mut graph = Dag::new(); + let mut builder = graph.builder("_"); + let input1 = builder.add_input(1, Shape::number()); - let input2 = graph.add_input(2, Shape::number()); + let input2 = builder.add_input(2, Shape::number()); let cpx_add = LevelledComplexity::ADDITION; - let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum"); + let sum1 = builder.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum"); - let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN, 1); + let lut1 = builder.add_lut(sum1, FunctionTable::UNKWOWN, 1); let concat = - graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat"); + builder.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat"); - let dot = graph.add_dot([concat], [1, 2]); + let dot = builder.add_dot([concat], [1, 2]); - let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN, 2); + let lut2 = builder.add_lut(dot, FunctionTable::UNKWOWN, 2); let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2]; for (expected_i, op_index) in ops_index.iter().enumerate() { - assert_eq!(expected_i, op_index.i); + assert_eq!(expected_i, op_index.0); } assert_eq!( @@ -471,7 +737,7 @@ mod tests { #[test] fn test_rounded_lut() { - let mut graph = OperationDag::new(); + let mut graph = Dag::new(); let out_precision = 5; let rounded_precision = 2; let input1 = graph.add_input(out_precision, Shape::number()); @@ -494,77 +760,77 @@ mod tests { weights: Weights::number(1 << 5), }, Operator::UnsafeCast { - input: OperatorIndex { i: 1 }, + input: OperatorIndex(1), out_precision: 0, }, //// 1 Bit to out_precision Operator::Lut { - input: OperatorIndex { i: 2 }, + input: OperatorIndex(2), table: FunctionTable::UNKWOWN, out_precision: 5, }, //// Erase bit Operator::Dot { - inputs: vec![input1, OperatorIndex { i: 3 }], + inputs: vec![input1, OperatorIndex(3)], weights: Weights::vector([1, -1]), }, Operator::UnsafeCast { - input: OperatorIndex { i: 4 }, + input: OperatorIndex(4), out_precision: 4, }, // Clear: cleared = input - bit0 - bit1 //// Extract bit Operator::Dot { - inputs: vec![OperatorIndex { i: 5 }], + inputs: vec![OperatorIndex(5)], weights: Weights::number(1 << 4), }, Operator::UnsafeCast { - input: OperatorIndex { i: 6 }, + input: OperatorIndex(6), out_precision: 0, }, //// 1 Bit to out_precision Operator::Lut { - input: OperatorIndex { i: 7 }, + input: OperatorIndex(7), table: FunctionTable::UNKWOWN, out_precision: 4, }, //// Erase bit Operator::Dot { - inputs: vec![OperatorIndex { i: 5 }, OperatorIndex { i: 8 }], + inputs: vec![OperatorIndex(5), OperatorIndex(8)], weights: Weights::vector([1, -1]), }, Operator::UnsafeCast { - input: OperatorIndex { i: 9 }, + input: OperatorIndex(9), out_precision: 3, }, // Clear: cleared = input - bit0 - bit1 - bit2 //// Extract bit Operator::Dot { - inputs: vec![OperatorIndex { i: 10 }], + inputs: vec![OperatorIndex(10)], weights: Weights::number(1 << 3), }, Operator::UnsafeCast { - input: OperatorIndex { i: 11 }, + input: OperatorIndex(11), out_precision: 0, }, //// 1 Bit to out_precision Operator::Lut { - input: OperatorIndex { i: 12 }, + input: OperatorIndex(12), table: FunctionTable::UNKWOWN, out_precision: 3, }, //// Erase bit Operator::Dot { - inputs: vec![OperatorIndex { i: 10 }, OperatorIndex { i: 13 }], + inputs: vec![OperatorIndex(10), OperatorIndex(13)], weights: Weights::vector([1, -1]), }, Operator::UnsafeCast { - input: OperatorIndex { i: 14 }, + input: OperatorIndex(14), out_precision: 2, }, // Lut on rounded precision Operator::Lut { - input: OperatorIndex { i: 15 }, + input: OperatorIndex(15), table: FunctionTable::UNKWOWN, out_precision: 5, }, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index bf1450803e..1f8bb65f45 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -46,7 +46,7 @@ pub struct AnalyzedDag { } pub fn analyze( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, noise_config: &NoiseBoundConfig, p_cut: &Option, default_partition: PartitionIndex, @@ -70,8 +70,8 @@ pub fn analyze( check_composability(&dag, &out_variances, nb_partitions)?; // Get the largest output out_variance let largest_output_variances = dag - .get_output_index_iter() - .map(|index| out_variances[index].clone()) + .get_output_operators_iter() + .map(|op| out_variances[op.id.0].clone()) .reduce(|lhs, rhs| { lhs.into_iter() .zip(rhs) @@ -110,7 +110,7 @@ pub fn analyze( } fn check_composability( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, symbolic_variances: &[Vec], nb_partitions: usize, ) -> Result<()> { @@ -126,8 +126,8 @@ fn check_composability( // If the circuit outputs are free from input variances, it means that every outputs are // refreshed, and the function can be composed. let in_var_in_out_var = dag - .get_output_index_iter() - .flat_map(|index| symbolic_variances[index].iter().map(move |v| (index, v))) + .get_output_operators_iter() + .flat_map(|op| symbolic_variances[op.id.0].iter().map(move |v| (op.id, v))) .find_map(|(output_index, sym_var)| { (0..nb_partitions) .find(|partition| { @@ -164,8 +164,8 @@ pub fn original_instrs_partition( // let mut extra_conversion_keys = None; for (i, new_instruction) in new_instructions.iter().enumerate() { // focus on TLU information - let new_instr_part = &dag.instrs_partition[new_instruction.i]; - if let Op::Lut { .. } = dag.operators[new_instruction.i] { + let new_instr_part = &dag.instrs_partition[new_instruction.0]; + if let Op::Lut { .. } = dag.operators[new_instruction.0] { let ks_dst = new_instr_part.instruction_partition; partition = Some(ks_dst); #[allow(clippy::match_on_vec_items)] @@ -201,7 +201,7 @@ pub fn original_instrs_partition( ); } let partition = - partition.unwrap_or(dag.instrs_partition[new_instructions[0].i].instruction_partition); + partition.unwrap_or(dag.instrs_partition[new_instructions[0].0].instruction_partition); let input_partition = input_partition.unwrap_or(partition); let merged = keys_spec::InstructionKeys { input_key: big_keys[input_partition].identifier, @@ -218,7 +218,7 @@ pub fn original_instrs_partition( } fn out_variance( - op: &unparametrized::UnparameterizedOperator, + op: &Operator, out_shapes: &[Shape], out_variances: &[Vec], nb_partitions: usize, @@ -232,12 +232,12 @@ fn out_variance( // one variance per partition, in case the result is converted let partition = instr_partition.instruction_partition; let out_variance_of = |input: &OperatorIndex| { - assert!(input.i < out_variances.len()); - assert!(partition < out_variances[input.i].len()); - assert!(out_variances[input.i][partition] != SymbolicVariance::ZERO); - assert!(!out_variances[input.i][partition].coeffs.values[0].is_nan()); - assert!(out_variances[input.i][partition].partition != usize::MAX); - out_variances[input.i][partition].clone() + assert!(input.0 < out_variances.len()); + assert!(partition < out_variances[input.0].len()); + assert!(out_variances[input.0][partition] != SymbolicVariance::ZERO); + assert!(!out_variances[input.0][partition].coeffs.values[0].is_nan()); + assert!(out_variances[input.0][partition].partition != usize::MAX); + out_variances[input.0][partition].clone() }; let max_variance = |acc: SymbolicVariance, input: SymbolicVariance| acc.max(&input); let variance = match op { @@ -293,7 +293,7 @@ fn out_variance( } fn out_variances( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, nb_partitions: usize, instrs_partition: &[InstructionPartition], input_override: &Option>, @@ -315,7 +315,7 @@ fn out_variances( } fn variance_constraint( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, noise_config: &NoiseBoundConfig, partition: PartitionIndex, op_i: usize, @@ -336,18 +336,18 @@ fn variance_constraint( #[allow(clippy::float_cmp)] #[allow(clippy::match_on_vec_items)] fn collect_all_variance_constraints( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, noise_config: &NoiseBoundConfig, instrs_partition: &[InstructionPartition], out_variances: &[Vec], ) -> Vec { let mut constraints = vec![]; - for (op_i, op) in dag.operators.iter().enumerate() { - let partition = instrs_partition[op_i].instruction_partition; - if let Op::Lut { input, .. } = op { - let precision = dag.out_precisions[input.i]; + for op in dag.get_operators_iter() { + let partition = instrs_partition[op.id.0].instruction_partition; + if let Op::Lut { input, .. } = op.operator { + let precision = dag.out_precisions[input.0]; let dst_partition = partition; - let src_partition = match instrs_partition[op_i].inputs_transition[0] { + let src_partition = match instrs_partition[op.id.0].inputs_transition[0] { None => dst_partition, Some(Transition::Internal { src_partition }) => { assert!(src_partition != dst_partition); @@ -355,7 +355,7 @@ fn collect_all_variance_constraints( } Some(Transition::Additional { src_partition }) => { assert!(src_partition != dst_partition); - let variance = &out_variances[input.i][dst_partition]; + let variance = &out_variances[input.0][dst_partition]; assert!( variance.coeff_partition_keyswitch_to_big(src_partition, dst_partition) == 1.0 @@ -363,7 +363,7 @@ fn collect_all_variance_constraints( dst_partition } }; - let variance = &out_variances[input.i][src_partition].clone(); + let variance = &out_variances[input.0][src_partition].clone(); let variance = variance .after_partition_keyswitch_to_small(src_partition, dst_partition) .after_modulus_switching(partition); @@ -371,19 +371,19 @@ fn collect_all_variance_constraints( dag, noise_config, partition, - op_i, + op.id.0, precision, variance, )); } - if dag.is_output_node(op_i) { - let precision = dag.out_precisions[op_i]; - let variance = out_variances[op_i][partition].clone(); + if op.is_output() { + let precision = dag.out_precisions[op.id.0]; + let variance = out_variances[op.id.0][partition].clone(); constraints.push(variance_constraint( dag, noise_config, partition, - op_i, + op.id.0, precision, variance, )); @@ -394,15 +394,15 @@ fn collect_all_variance_constraints( #[allow(clippy::match_on_vec_items)] fn operations_counts( - dag: &unparametrized::OperationDag, - op: &unparametrized::UnparameterizedOperator, + dag: &unparametrized::Dag, + op: &Operator, nb_partitions: usize, instr_partition: &InstructionPartition, ) -> OperationsCount { let mut counts = OperationsValue::zero(nb_partitions); if let Op::Lut { input, .. } = op { let partition = instr_partition.instruction_partition; - let nb_lut = dag.out_shapes[input.i].flat_size() as f64; + let nb_lut = dag.out_shapes[input.0].flat_size() as f64; let src_partition = match instr_partition.inputs_transition[0] { Some(Transition::Internal { src_partition }) => src_partition, Some(Transition::Additional { .. }) | None => partition, @@ -417,7 +417,7 @@ fn operations_counts( } fn collect_operations_count( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, nb_partitions: usize, instrs_partition: &[InstructionPartition], ) -> Vec { @@ -446,12 +446,12 @@ pub mod tests { }; use crate::optimization::dag::solo_key::analyze::tests::CONFIG; - pub fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag { + pub fn analyze(dag: &unparametrized::Dag) -> AnalyzedDag { analyze_with_preferred(dag, LOW_PRECISION_PARTITION) } pub fn analyze_with_preferred( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, default_partition: PartitionIndex, ) -> AnalyzedDag { let p_cut = PartitionCut::for_each_precision(dag); @@ -524,21 +524,19 @@ pub mod tests { #[test] fn test_composition_with_inputs_only() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let _ = dag.add_input(1, Shape::number()); let p_cut = PartitionCut::for_each_precision(&dag); - dag.detect_outputs(); let res = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION, true); assert!(res.is_ok()); } #[test] fn test_composition_1_partition() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(1, Shape::number()); let _ = dag.add_lut(input1, FunctionTable::UNKWOWN, 2); let p_cut = PartitionCut::for_each_precision(&dag); - dag.detect_outputs(); let dag = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION, true).unwrap(); assert!(dag.nb_partitions == 1); @@ -556,13 +554,12 @@ pub mod tests { #[test] fn test_composition_2_partitions() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(3, Shape::number()); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 6); let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); let input2 = dag.add_dot([input1, lut3], [1, 1]); let _ = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); - dag.detect_outputs(); let analyzed_dag = super::analyze(&dag, &CONFIG, &None, LOW_PRECISION_PARTITION, true).unwrap(); assert_eq!(analyzed_dag.nb_partitions, 2); @@ -597,7 +594,7 @@ pub mod tests { #[test] fn test_composition_3_partitions() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(3, Shape::number()); let input2 = dag.add_input(13, Shape::number()); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 6); @@ -606,7 +603,6 @@ pub mod tests { let b = dag.add_dot([input1, lut3], [1, 1]); let _ = dag.add_lut(a, FunctionTable::UNKWOWN, 3); let _ = dag.add_lut(b, FunctionTable::UNKWOWN, 3); - dag.detect_outputs(); let analyzed_dag = super::analyze(&dag, &CONFIG, &None, 1, true).unwrap(); assert_eq!(analyzed_dag.nb_partitions, 3); let actual_constraint_strings = analyzed_dag @@ -643,7 +639,7 @@ pub mod tests { #[allow(clippy::needless_range_loop)] #[test] fn test_lut_sequence() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(8, Shape::number()); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 1); @@ -658,12 +654,11 @@ pub mod tests { LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION, ]; - dag.detect_outputs(); let dag = analyze(&dag); assert!(dag.nb_partitions == 2); - for op_i in input1.i..=lut5.i { + for op_i in input1.0..=lut5.0 { let p = &dag.instrs_partition[op_i]; - let is_input = op_i == input1.i; + let is_input = op_i == input1.0; assert!(p.instruction_partition == partitions[op_i]); if is_input { assert_input_on(&dag, p.instruction_partition, op_i, 1.0); @@ -677,7 +672,7 @@ pub mod tests { #[test] fn test_levelled_op() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let out_shape = Shape::number(); let manp = 8.0; let input1 = dag.add_input(8, Shape::number()); @@ -690,7 +685,6 @@ pub mod tests { &out_shape, "comment", ); - dag.detect_outputs(); let dag = analyze(&dag); assert!(dag.nb_partitions == 1); } @@ -704,18 +698,17 @@ pub mod tests { fn test_rounded_v3_first_layer_and_second_layer() { let acc_precision = 16; let precision = 8; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(acc_precision, Shape::number()); let rounded1 = dag.add_expanded_round(input1, precision); let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); let rounded2 = dag.add_expanded_round(lut1, precision); let lut2 = dag.add_lut(rounded2, FunctionTable::UNKWOWN, acc_precision); - dag.detect_outputs(); let old_dag = dag; let dag = analyze(&old_dag); show_partitionning(&old_dag, &dag.instrs_partition); // First layer is fully LOW_PRECISION_PARTITION - for op_i in input1.i..lut1.i { + for op_i in input1.0..lut1.0 { let p = LOW_PRECISION_PARTITION; let sb = &dag.out_variances[op_i][p]; assert!(sb.coeff_input(p) >= 1.0 || sb.coeff_pbs(p) >= 1.0); @@ -725,10 +718,10 @@ pub mod tests { } // First lut is HIGH_PRECISION_PARTITION and immedialtely converted to LOW_PRECISION_PARTITION let p = HIGH_PRECISION_PARTITION; - let sb = &dag.out_variances[lut1.i][p]; + let sb = &dag.out_variances[lut1.0][p]; assert!(sb.coeff_input(p) == 0.0); assert!(sb.coeff_pbs(p) == 1.0); - let sb_after_fast_ks = &dag.out_variances[lut1.i][LOW_PRECISION_PARTITION]; + let sb_after_fast_ks = &dag.out_variances[lut1.0][LOW_PRECISION_PARTITION]; assert!( sb_after_fast_ks.coeff_partition_keyswitch_to_big( HIGH_PRECISION_PARTITION, @@ -736,7 +729,7 @@ pub mod tests { ) == 1.0 ); // The next rounded is on LOW_PRECISION_PARTITION but base noise can comes from HIGH_PRECISION_PARTITION + FKS - for op_i in (lut1.i + 1)..lut2.i { + for op_i in (lut1.0 + 1)..lut2.0 { assert!(LOW_PRECISION_PARTITION == dag.instrs_partition[op_i].instruction_partition); let p = LOW_PRECISION_PARTITION; let sb = &dag.out_variances[op_i][p]; @@ -762,9 +755,9 @@ pub mod tests { } } assert!(nan_symbolic_variance( - &dag.out_variances[lut2.i][LOW_PRECISION_PARTITION] + &dag.out_variances[lut2.0][LOW_PRECISION_PARTITION] )); - let sb = &dag.out_variances[lut2.i][HIGH_PRECISION_PARTITION]; + let sb = &dag.out_variances[lut2.0][HIGH_PRECISION_PARTITION]; assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) >= 1.0); } @@ -773,23 +766,22 @@ pub mod tests { fn test_rounded_v3_classic_first_layer_second_layer() { let acc_precision = 16; let precision = 8; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let free_input1 = dag.add_input(precision, Shape::number()); let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); let rounded1 = dag.add_expanded_round(input1, precision); let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); - dag.detect_outputs(); let old_dag = dag; let dag = analyze(&old_dag); show_partitionning(&old_dag, &dag.instrs_partition); // First layer is fully HIGH_PRECISION_PARTITION assert!( - dag.out_variances[free_input1.i][HIGH_PRECISION_PARTITION] + dag.out_variances[free_input1.0][HIGH_PRECISION_PARTITION] .coeff_input(HIGH_PRECISION_PARTITION) == 1.0 ); // First layer tlu - let sb = &dag.out_variances[input1.i][HIGH_PRECISION_PARTITION]; + let sb = &dag.out_variances[input1.0][HIGH_PRECISION_PARTITION]; assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0); assert!( @@ -797,7 +789,7 @@ pub mod tests { == 0.0 ); // The same cyphertext exists in another partition with additional noise due to fast keyswitch - let sb = &dag.out_variances[input1.i][LOW_PRECISION_PARTITION]; + let sb = &dag.out_variances[input1.0][LOW_PRECISION_PARTITION]; assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0); assert!( @@ -808,7 +800,7 @@ pub mod tests { // Second layer let mut first_bit_extract_verified = false; let mut first_bit_erase_verified = false; - for op_i in (input1.i + 1)..rounded1.i { + for op_i in (input1.0 + 1)..rounded1.0 { if let Op::Dot { weights, inputs, .. } = &dag.operators[op_i] @@ -817,7 +809,7 @@ pub mod tests { let first_bit_extract = bit_extract && !first_bit_extract_verified; let bit_erase = weights.values == [1, -1]; let first_bit_erase = bit_erase && !first_bit_erase_verified; - let input0_sb = &dag.out_variances[inputs[0].i][LOW_PRECISION_PARTITION]; + let input0_sb = &dag.out_variances[inputs[0].0][LOW_PRECISION_PARTITION]; let input0_coeff_pbs_high = input0_sb.coeff_pbs(HIGH_PRECISION_PARTITION); let input0_coeff_pbs_low = input0_sb.coeff_pbs(LOW_PRECISION_PARTITION); let input0_coeff_fks = input0_sb.coeff_partition_keyswitch_to_big( @@ -835,7 +827,7 @@ pub mod tests { assert!(input0_coeff_fks == 1.0); } else if bit_erase { first_bit_erase_verified |= first_bit_erase; - let input1_sb = &dag.out_variances[inputs[1].i][LOW_PRECISION_PARTITION]; + let input1_sb = &dag.out_variances[inputs[1].0][LOW_PRECISION_PARTITION]; let input1_coeff_pbs_high = input1_sb.coeff_pbs(HIGH_PRECISION_PARTITION); let input1_coeff_pbs_low = input1_sb.coeff_pbs(LOW_PRECISION_PARTITION); let input1_coeff_fks = input1_sb.coeff_partition_keyswitch_to_big( @@ -863,12 +855,11 @@ pub mod tests { fn test_rounded_v3_classic_first_layer_second_layer_constraints() { let acc_precision = 7; let precision = 4; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let free_input1 = dag.add_input(precision, Shape::number()); let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); let rounded1 = dag.add_expanded_round(input1, precision); let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision); - dag.detect_outputs(); let old_dag = dag; let dag = analyze(&old_dag); show_partitionning(&old_dag, &dag.instrs_partition); @@ -924,13 +915,12 @@ pub mod tests { fn test_rounded_v1_classic_first_layer_second_layer_constraints() { let acc_precision = 7; let precision = 4; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let free_input1 = dag.add_input(precision, Shape::number()); let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); // let input1 = dag.add_input(acc_precision, Shape::number()); let rounded1 = dag.add_expanded_round(input1, precision); let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision); - dag.detect_outputs(); let old_dag = dag; let dag = analyze_with_preferred(&old_dag, HIGH_PRECISION_PARTITION); show_partitionning(&old_dag, &dag.instrs_partition); @@ -983,12 +973,11 @@ pub mod tests { fn test_rounded_v3_classic_first_layer_second_layer_complexity() { let acc_precision = 7; let precision = 4; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let free_input1 = dag.add_input(precision, Shape::number()); let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); let rounded1 = dag.add_expanded_round(input1, precision); let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision); - dag.detect_outputs(); let old_dag = dag; let dag = analyze(&old_dag); // Partition 0 @@ -1033,7 +1022,7 @@ pub mod tests { #[test] fn test_high_partition_number() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let max_precision = 10; let mut lut_input = dag.add_input(max_precision, Shape::number()); for out_precision in (1..=max_precision).rev() { @@ -1042,7 +1031,6 @@ pub mod tests { _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, 1); let precisions: Vec<_> = (1..=max_precision).collect(); let p_cut = PartitionCut::from_precisions(&precisions); - dag.detect_outputs(); let dag = super::analyze( &dag, &CONFIG, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index 575ecfe1c8..661b7e13bd 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -12,4 +12,3 @@ mod partitions; mod symbolic_variance; mod union_find; mod variance_constraint; -mod visualization; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs index 2151c9c782..b60ede5e29 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs @@ -1,7 +1,7 @@ // OPT: cache for fks and verified pareto use concrete_cpu_noise_model::gaussian_noise::noise::modulus_switching::estimate_modulus_switching_noise_with_binary_key; -use crate::dag::unparametrized; +use crate::dag::unparametrized::Dag; use crate::noise_estimator::error; use crate::optimization; use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace}; @@ -900,7 +900,7 @@ fn cross_partition(nb_partitions: usize) -> impl Iterator #[allow(clippy::too_many_lines, clippy::missing_errors_doc)] pub fn optimize( - dag: &unparametrized::OperationDag, + dag: &Dag, config: Config, search_space: &SearchSpace, persistent_caches: &PersistDecompCaches, @@ -1156,7 +1156,7 @@ fn sanity_check( } pub fn optimize_to_circuit_solution( - dag: &unparametrized::OperationDag, + dag: &Dag, config: Config, search_space: &SearchSpace, persistent_caches: &PersistDecompCaches, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index 835916e48d..5784676693 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -46,7 +46,7 @@ fn default_config() -> Config<'static> { } fn optimize( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, p_cut: &Option, default_partition: usize, ) -> Option { @@ -63,11 +63,11 @@ fn optimize( .map_or(None, |v| Some(v.1)) } -fn optimize_single(dag: &unparametrized::OperationDag) -> Option { +fn optimize_single(dag: &unparametrized::Dag) -> Option { optimize(dag, &Some(PartitionCut::empty()), LOW_PARTITION) } -fn equiv_single(dag: &unparametrized::OperationDag) -> Option { +fn equiv_single(dag: &unparametrized::Dag) -> Option { let sol_mono = solo_key::optimize::tests::optimize(dag); let sol_multi = optimize_single(dag); if sol_mono.best_solution.is_none() != sol_multi.is_none() { @@ -121,11 +121,10 @@ fn optimize_simple_parameter_rounded_lut_2_layers() { } fn equiv_2_single( - dag_multi: &unparametrized::OperationDag, - dag_1: &unparametrized::OperationDag, - dag_2: &unparametrized::OperationDag, + dag_multi: &unparametrized::Dag, + dag_1: &unparametrized::Dag, + dag_2: &unparametrized::Dag, ) -> Option { - eprintln!("{dag_multi}"); let sol_single_1 = solo_key::optimize::tests::optimize(dag_1); let sol_single_2 = solo_key::optimize::tests::optimize(dag_2); let sol_multi = optimize(dag_multi, &None, LOW_PARTITION); @@ -177,11 +176,8 @@ fn optimize_multi_independant_2_precisions() { let noise_factor = manp as f64; let mut dag_multi = v0_dag(sum_size, precision1, noise_factor); add_v0_dag(&mut dag_multi, sum_size, precision2, noise_factor); - dag_multi.detect_outputs(); - let mut dag_1 = v0_dag(sum_size, precision1, noise_factor); - dag_1.detect_outputs(); - let mut dag_2 = v0_dag(sum_size, precision2, noise_factor); - dag_2.detect_outputs(); + let dag_1 = v0_dag(sum_size, precision1, noise_factor); + let dag_2 = v0_dag(sum_size, precision2, noise_factor); if let Some(equiv) = equiv_2_single(&dag_multi, &dag_1, &dag_2) { assert!(equiv, "FAILED ON {precision1} {precision2} {manp}"); } else { @@ -196,8 +192,8 @@ fn dag_lut_sum_of_2_partitions_2_layer( precision1: u8, precision2: u8, final_lut: bool, -) -> unparametrized::OperationDag { - let mut dag = unparametrized::OperationDag::new(); +) -> unparametrized::Dag { + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(precision1, Shape::number()); let input2 = dag.add_input(precision2, Shape::number()); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision1); @@ -208,7 +204,6 @@ fn dag_lut_sum_of_2_partitions_2_layer( if final_lut { _ = dag.add_lut(dot, FunctionTable::UNKWOWN, precision1); } - dag.detect_outputs(); dag } @@ -338,19 +333,16 @@ fn optimize_multi_independant_2_partitions_finally_added_and_luted() { } } -fn optimize_rounded(dag: &unparametrized::OperationDag) -> Option { +fn optimize_rounded(dag: &unparametrized::Dag) -> Option { let p_cut = Some(PartitionCut::from_precisions(&[1, 128])); let default_partition = 0; optimize(dag, &p_cut, default_partition) } -fn dag_rounded_lut_2_layers( - accumulator_precision: usize, - precision: usize, -) -> unparametrized::OperationDag { +fn dag_rounded_lut_2_layers(accumulator_precision: usize, precision: usize) -> unparametrized::Dag { let out_precision = accumulator_precision as u8; let rounded_precision = precision as u8; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(precision as u8, Shape::number()); let rounded1 = dag.add_expanded_rounded_lut( input1, @@ -418,7 +410,7 @@ fn test_optimize_v3_expanded_round_16_6() { #[test] fn optimize_v3_direct_round() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(16, Shape::number()); _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16); let sol = optimize_rounded(&dag).unwrap(); @@ -437,7 +429,7 @@ fn optimize_v3_direct_round() { fn optimize_sign_extract() { let precision = 8; let high_precision = 16; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let complexity = LevelledComplexity::ZERO; let free_small_input1 = dag.add_input(precision, Shape::number()); let small_input1 = dag.add_lut(free_small_input1, FunctionTable::UNKWOWN, precision); @@ -467,7 +459,7 @@ fn test_partition_chain(decreasing: bool) { // tlu chain with decreasing precision (decreasing partition index) // check that increasing partitionning gaves faster solutions // check solution has the right structure - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let min_precision = 6; let max_precision = 8; let mut input_precisions: Vec<_> = (min_precision..=max_precision).collect(); @@ -588,7 +580,7 @@ fn test_chained_partitions_non_feasible_single_params() { // generate hard circuit, non feasible with single parameters let precisions = [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; // Note: reversing chain have issues for connecting lower bits to 7 bits, there may be no feasible solution - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let mut lut_input = dag.add_input(precisions[0], Shape::number()); for out_precision in precisions { let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64; @@ -614,12 +606,11 @@ fn test_chained_partitions_non_feasible_single_params() { #[test] fn test_multi_rounded_fks_coherency() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(16, Shape::number()); let reduced_8 = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 8); let reduced_4 = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); _ = dag.add_dot([reduced_8, reduced_4], [1, 1]); - dag.detect_outputs(); let sol = optimize(&dag, &None, 0); assert!(sol.is_some()); let sol = sol.unwrap(); @@ -633,7 +624,7 @@ fn test_multi_rounded_fks_coherency() { #[test] fn test_levelled_only() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let _ = dag.add_input(22, Shape::number()); let config = default_config(); let search_space = SearchSpace::default_cpu(); @@ -648,14 +639,13 @@ fn test_levelled_only() { #[test] fn test_big_secret_key_sharing() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(4, Shape::number()); let input2 = dag.add_input(5, Shape::number()); let input2 = dag.add_dot([input2], [128]); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 5); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 5); let _ = dag.add_dot([lut1, lut2], [16, 1]); - dag.detect_outputs(); let config_sharing = Config { security_level: 128, maximum_acceptable_error_probability: _4_SIGMA, @@ -699,14 +689,13 @@ fn test_big_secret_key_sharing() { #[test] fn test_big_and_small_secret_key() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(4, Shape::number()); let input2 = dag.add_input(5, Shape::number()); let input2 = dag.add_dot([input2], [128]); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 5); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 5); let _ = dag.add_dot([lut1, lut2], [16, 1]); - dag.detect_outputs(); let config_sharing = Config { security_level: 128, maximum_acceptable_error_probability: _4_SIGMA, @@ -751,13 +740,12 @@ fn test_big_and_small_secret_key() { #[test] fn test_composition_2_partitions() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(3, Shape::number()); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 6); let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); let input2 = dag.add_dot([input1, lut3], [1, 1]); let _ = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); - dag.detect_outputs(); let normal_config = default_config(); let composed_config = Config { composable: true, @@ -783,12 +771,11 @@ fn test_composition_2_partitions() { #[test] fn test_composition_1_partition_not_composable() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(8, Shape::number()); let input1 = dag.add_dot([input1], [1 << 16]); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); let _ = dag.add_dot([lut1], [1 << 16]); - dag.detect_outputs(); let normal_config = default_config(); let composed_config = Config { composable: true, @@ -812,12 +799,11 @@ fn test_composition_1_partition_not_composable() { fn test_maximal_multi() { let config = default_config(); let search_space = SearchSpace::default_cpu(); - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input = dag.add_input(8, Shape::number()); let lut1 = dag.add_lut(input, FunctionTable::UNKWOWN, 8u8); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 8u8); _ = dag.add_dot([lut2], [1 << 16]); - dag.detect_outputs(); let sol = optimize(&dag, &None, 0).unwrap(); assert!(sol.macro_params.len() == 1); @@ -850,7 +836,7 @@ fn test_maximal_multi() { fn test_bug_with_zero_noise() { let complexity = LevelledComplexity::ZERO; let out_shape = Shape::number(); - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let v0 = dag.add_input(2, &out_shape); let v1 = dag.add_levelled_op([v0], complexity, 0.0, &out_shape, "comment"); let v2 = dag.add_levelled_op([v1], complexity, 1.0, &out_shape, "comment"); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs index 98d9e0865c..049ed02f98 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs @@ -1,4 +1,4 @@ -use crate::dag::unparametrized::OperationDag; +use crate::dag::unparametrized::Dag; use crate::optimization::config::{Config, SearchSpace}; use crate::optimization::dag::multi_parameters::keys_spec::CircuitSolution; use crate::optimization::dag::multi_parameters::optimize::optimize_to_circuit_solution as native_optimize; @@ -26,7 +26,7 @@ fn best_complexity_solution(native: CircuitSolution, crt: CircuitSolution) -> Ci } fn crt_optimize( - dag: &OperationDag, + dag: &Dag, config: Config, search_space: &SearchSpace, default_log_norm2_woppbs: f64, @@ -54,7 +54,7 @@ fn crt_optimize( } pub fn optimize( - dag: &OperationDag, + dag: &Dag, config: Config, search_space: &SearchSpace, encoding: Encoding, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs index 20b93a304e..e25d4532f1 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs @@ -49,21 +49,21 @@ impl PartitionCut { if self.rnorm2.is_empty() { return f64::MAX; } - assert!(!self.rnorm2[op_i.i].is_nan()); - self.rnorm2[op_i.i] + assert!(!self.rnorm2[op_i.0].is_nan()); + self.rnorm2[op_i.0] } pub fn partition( &self, - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, op_i: OperatorIndex, ) -> Option { - let op = &dag.operators[op_i.i]; + let op = &dag.operators[op_i.0]; match op { Operator::Lut { input, .. } => { assert!(!self.p_cut.is_empty()); for (partition, &(precision_cut, norm2_cut)) in self.p_cut.iter().enumerate() { - if dag.out_precisions[input.i] <= precision_cut + if dag.out_precisions[input.0] <= precision_cut && self.rnorm2(op_i) <= norm2_cut { return Some(partition); @@ -75,19 +75,19 @@ impl PartitionCut { } } - pub fn for_each_precision(dag: &unparametrized::OperationDag) -> Self { + pub fn for_each_precision(dag: &unparametrized::Dag) -> Self { let (dag, _) = expand_round_and_index_map(dag); let mut lut_in_precisions: HashSet<_> = HashSet::default(); for op in &dag.operators { if let Operator::Lut { input, .. } = op { - _ = lut_in_precisions.insert(dag.out_precisions[input.i]); + _ = lut_in_precisions.insert(dag.out_precisions[input.0]); } } let precisions: Vec<_> = lut_in_precisions.iter().copied().collect(); Self::from_precisions(&precisions) } - pub fn maximal_partitionning(original_dag: &unparametrized::OperationDag) -> Self { + pub fn maximal_partitionning(original_dag: &unparametrized::Dag) -> Self { // Note: only keep one 0-bits, partition as the compiler will not support multi-parameter round // partition based on input precision and output log norm2 let (dag, rewrited) = expand_round_and_index_map(original_dag); @@ -95,7 +95,7 @@ impl PartitionCut { for (round_i, op) in original_dag.operators.iter().enumerate() { if let Operator::Round { .. } = op { for op in &rewrited[round_i] { - let already = round_index.insert(op.i, round_i); + let already = round_index.insert(op.0, round_i); assert!(already.is_none()); } } @@ -112,12 +112,12 @@ impl PartitionCut { Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { let mut origins = HashSet::default(); for input in inputs { - origins.extend(&noise_origins[input.i]); + origins.extend(&noise_origins[input.0]); } noise_origins[op_i] = origins; } Operator::UnsafeCast { input, .. } => { - noise_origins[op_i] = noise_origins[input.i].clone(); + noise_origins[op_i] = noise_origins[input.0].clone(); } // origins Operator::Lut { .. } => { @@ -136,16 +136,16 @@ impl PartitionCut { let mut lut_partition: HashSet<_> = HashSet::default(); for dest in &dag.operators { if let Operator::Lut { input, .. } = dest { - for &origin in &noise_origins[input.i] { - let norm2 = out_norm2(input.i); + for &origin in &noise_origins[input.0] { + let norm2 = out_norm2(input.0); max_output_norm2[origin] = max_output_norm2[origin].max(norm2); assert!(!max_output_norm2[origin].is_nan()); } } } - for op_i in dag.get_output_index_iter() { - for &origin in &noise_origins[op_i] { - max_output_norm2[origin] = max_output_norm2[origin].max(out_norm2(op_i)); + for op in dag.get_output_operators_iter() { + for &origin in &noise_origins[op.id.0] { + max_output_norm2[origin] = max_output_norm2[origin].max(out_norm2(op.id.0)); assert!(!max_output_norm2[origin].is_nan()); } } @@ -153,7 +153,7 @@ impl PartitionCut { // reassociate all lut's output_norm2 and precisions for (op_i, output_norm2) in max_output_norm2.iter_mut().enumerate() { if let Operator::Lut { input, .. } = dag.operators[op_i] { - let input_precision = dag.out_precisions[input.i]; + let input_precision = dag.out_precisions[input.0]; let output_precision = dag.out_precisions[op_i] as i32; let delta_precision = output_precision - input_precision as i32; assert!(!output_norm2.is_nan()); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index b547211ff2..83707f3719 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -46,7 +46,7 @@ impl Blocks { // Extract block of instructions connected by levelled ops. // This facilitates reasonning about conflicts on levelled ops. #[allow(clippy::match_same_arms)] -fn extract_levelled_block(dag: &unparametrized::OperationDag, composable: bool) -> Blocks { +fn extract_levelled_block(dag: &unparametrized::Dag, composable: bool) -> Blocks { let mut uf = UnionFind::new(dag.operators.len()); for (op_i, op) in dag.operators.iter().enumerate() { match op { @@ -55,10 +55,10 @@ fn extract_levelled_block(dag: &unparametrized::OperationDag, composable: bool) // Block entry point and pre-exit point Op::Lut { .. } => (), // Connectors - Op::UnsafeCast { input, .. } => uf.union(input.i, op_i), + Op::UnsafeCast { input, .. } => uf.union(input.0, op_i), Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => { for input in inputs { - uf.union(input.i, op_i); + uf.union(input.0, op_i); } } Op::Round { .. } => unreachable!("Round should have been expanded"), @@ -67,9 +67,10 @@ fn extract_levelled_block(dag: &unparametrized::OperationDag, composable: bool) if composable { // Without knowledge of how outputs are forwarded to inputs, we can't do better than putting // all inputs and outputs in the same partition. - let mut input_iter = dag.get_input_index_iter(); + let mut input_iter = dag.get_input_operators_iter().map(|op| op.id.0); let first_inp = input_iter.next().unwrap(); - dag.get_output_index_iter() + dag.get_output_operators_iter() + .map(|op| op.id.0) .chain(input_iter) .for_each(|ind| uf.union(first_inp, ind)); } @@ -84,7 +85,7 @@ struct BlockConstraints { /* For each levelled block collect BlockConstraints */ fn levelled_blocks_constraints( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, blocks: &Blocks, p_cut: &PartitionCut, ) -> Vec { @@ -92,10 +93,10 @@ fn levelled_blocks_constraints( for (block_i, ops_i) in blocks.blocks.iter().enumerate() { for &op_i in ops_i { let op = &dag.operators[op_i]; - if let Some(partition) = p_cut.partition(dag, OperatorIndex { i: op_i }) { + if let Some(partition) = p_cut.partition(dag, OperatorIndex(op_i)) { _ = constraints_by_block[block_i].forced.insert(partition); if let Some(input) = op_tlu_inputs(op) { - let input_group = blocks.block_of[input.i]; + let input_group = blocks.block_of[input.0]; constraints_by_block[input_group].exit.extend([partition]); } } @@ -115,7 +116,7 @@ fn get_singleton_value(hashset: &HashSet) -> V { *hashset.iter().next().unwrap() } -fn only_1_partition(dag: &unparametrized::OperationDag) -> Partitions { +fn only_1_partition(dag: &unparametrized::Dag) -> Partitions { let mut instrs_partition = vec![InstructionPartition::new(0); dag.operators.len()]; for (op_i, op) in dag.operators.iter().enumerate() { match op { @@ -136,7 +137,7 @@ fn only_1_partition(dag: &unparametrized::OperationDag) -> Partitions { } fn resolve_by_levelled_block( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, p_cut: &PartitionCut, default_partition: PartitionIndex, composable: bool, @@ -194,10 +195,9 @@ fn resolve_by_levelled_block( let group_partition = block_partition_of(op_i); match op { Op::Lut { input, .. } => { - let instruction_partition = - p_cut.partition(dag, OperatorIndex { i: op_i }).unwrap(); + let instruction_partition = p_cut.partition(dag, OperatorIndex(op_i)).unwrap(); instrs_p[op_i].instruction_partition = instruction_partition; - let input_partition = instrs_p[input.i].instruction_partition; + let input_partition = instrs_p[input.0].instruction_partition; instrs_p[op_i].inputs_transition = if input_partition == instruction_partition { vec![None] } else { @@ -214,7 +214,7 @@ fn resolve_by_levelled_block( instrs_p[op_i].instruction_partition = group_partition; instrs_p[op_i].inputs_transition = vec![None; inputs.len()]; for (i, input) in inputs.iter().enumerate() { - let input_partition = instrs_p[input.i].instruction_partition; + let input_partition = instrs_p[input.0].instruction_partition; if group_partition != input_partition { instrs_p[op_i].inputs_transition[i] = Some(Transition::Additional { src_partition: input_partition, @@ -224,7 +224,7 @@ fn resolve_by_levelled_block( } Op::UnsafeCast { input, .. } => { instrs_p[op_i].instruction_partition = group_partition; - let input_partition = instrs_p[input.i].instruction_partition; + let input_partition = instrs_p[input.0].instruction_partition; instrs_p[op_i].inputs_transition = if group_partition == input_partition { vec![None] } else { @@ -248,7 +248,7 @@ fn resolve_by_levelled_block( } pub fn partitionning_with_preferred( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, p_cut: &PartitionCut, default_partition: PartitionIndex, composable: bool, @@ -275,12 +275,12 @@ pub mod tests { PartitionCut::from_precisions(&[2, 128]) } - fn partitionning_no_p_cut(dag: &unparametrized::OperationDag, composable: bool) -> Partitions { + fn partitionning_no_p_cut(dag: &unparametrized::Dag, composable: bool) -> Partitions { let p_cut = PartitionCut::empty(); partitionning_with_preferred(dag, &p_cut, LOW_PRECISION_PARTITION, composable) } - fn partitionning(dag: &unparametrized::OperationDag, composable: bool) -> Partitions { + fn partitionning(dag: &unparametrized::Dag, composable: bool) -> Partitions { partitionning_with_preferred( dag, &PartitionCut::for_each_precision(dag), @@ -290,7 +290,7 @@ pub mod tests { } fn partitionning_with_preferred( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, p_cut: &PartitionCut, default_partition: usize, composable: bool, @@ -298,10 +298,7 @@ pub mod tests { super::partitionning_with_preferred(dag, p_cut, default_partition, composable) } - pub fn show_partitionning( - dag: &unparametrized::OperationDag, - partitions: &[InstructionPartition], - ) { + pub fn show_partitionning(dag: &unparametrized::Dag, partitions: &[InstructionPartition]) { println!("Dag:"); for (i, op) in dag.operators.iter().enumerate() { let partition = partitions[i].instruction_partition; @@ -334,7 +331,7 @@ pub mod tests { #[test] fn test_1_partition() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(16, Shape::number()); _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); let instrs_partition = partitionning_no_p_cut(&dag, false).instrs_partition; @@ -346,7 +343,7 @@ pub mod tests { #[test] fn test_1_input_2_partitions() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); _ = dag.add_input(1, Shape::number()); let partitions = partitionning(&dag, false); assert!(partitions.nb_partitions == 1); @@ -357,26 +354,25 @@ pub mod tests { #[test] fn test_2_partitions_with_without_compo() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input = dag.add_input(10, Shape::number()); let lut1 = dag.add_lut(input, FunctionTable::UNKWOWN, 2); let output = dag.add_lut(lut1, FunctionTable::UNKWOWN, 10); - dag.detect_outputs(); let partitions = partitionning(&dag, false); assert!( - partitions.instrs_partition[input.i].instruction_partition - != partitions.instrs_partition[output.i].instruction_partition + partitions.instrs_partition[input.0].instruction_partition + != partitions.instrs_partition[output.0].instruction_partition ); let partitions = partitionning(&dag, true); assert!( - partitions.instrs_partition[input.i].instruction_partition - == partitions.instrs_partition[output.i].instruction_partition + partitions.instrs_partition[input.0].instruction_partition + == partitions.instrs_partition[output.0].instruction_partition ); } #[test] fn test_2_lut_sequence() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let mut expected_partitions = vec![]; let input1 = dag.add_input(8, Shape::number()); expected_partitions.push(HIGH_PRECISION_PARTITION); @@ -393,7 +389,7 @@ pub mod tests { let partitions = partitionning(&dag, false); assert!(partitions.nb_partitions == 2); let instrs_partition = partitions.instrs_partition; - let consider = |op_i: OperatorIndex| &instrs_partition[op_i.i]; + let consider = |op_i: OperatorIndex| &instrs_partition[op_i.0]; show_partitionning(&dag, &instrs_partition); assert!(consider(input1).instruction_partition == HIGH_PRECISION_PARTITION); // no constraint assert!(consider(lut1).instruction_partition == expected_partitions[1]); @@ -406,7 +402,7 @@ pub mod tests { #[test] fn test_mixed_dot_no_conflict_low() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(8, Shape::number()); let input2 = dag.add_input(1, Shape::number()); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8); @@ -417,7 +413,7 @@ pub mod tests { #[test] fn test_mixed_dot_no_conflict_high() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(8, Shape::number()); let input2 = dag.add_input(1, Shape::number()); let lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, 1); @@ -428,14 +424,14 @@ pub mod tests { #[test] fn test_mixed_dot_conflict() { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(8, Shape::number()); let input2 = dag.add_input(1, Shape::number()); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8); let dot = dag.add_dot([lut1, lut2], Weights::from([1, 1])); let partitions = partitionning(&dag, false); - let consider = |op_i: OperatorIndex| &partitions.instrs_partition[op_i.i]; + let consider = |op_i: OperatorIndex| &partitions.instrs_partition[op_i.0]; // input1 let p = consider(input1); { @@ -484,7 +480,7 @@ pub mod tests { fn test_rounded_v3_first_layer_and_second_layer() { let acc_precision = 8; let precision = 6; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(acc_precision, Shape::number()); let rounded1 = dag.add_expanded_round(input1, precision); let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); @@ -493,13 +489,13 @@ pub mod tests { let partitions = partitionning(&dag, false); let consider = |op_i| &partitions.instrs_partition[op_i]; // First layer is fully LOW_PRECISION_PARTITION - for op_i in input1.i..lut1.i { + for op_i in input1.0..lut1.0 { let p = consider(op_i); assert!(p.instruction_partition == LOW_PRECISION_PARTITION); assert!(p.no_transition()); } // First lut is HIGH_PRECISION_PARTITION and immedialtely converted to LOW_PRECISION_PARTITION - let p = consider(lut1.i); + let p = consider(lut1.0); { assert!(p.instruction_partition == HIGH_PRECISION_PARTITION); assert!( @@ -512,11 +508,11 @@ pub mod tests { })] ); }; - for op_i in (lut1.i + 1)..lut2.i { + for op_i in (lut1.0 + 1)..lut2.0 { let p = consider(op_i); assert!(p.instruction_partition == LOW_PRECISION_PARTITION); } - let p = consider(lut2.i); + let p = consider(lut2.0); { assert!(p.instruction_partition == HIGH_PRECISION_PARTITION); assert!(p.alternative_output_representation.is_empty()); @@ -533,12 +529,12 @@ pub mod tests { fn test_rounded_v3_classic_first_layer_second_layer() { let acc_precision = 8; let precision = 6; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let free_input1 = dag.add_input(precision, Shape::number()); let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); - let first_layer = free_input1.i..=input1.i; + let first_layer = free_input1.0..=input1.0; let rounded1 = dag.add_expanded_round(input1, precision); - let rounded_layer: Vec<_> = ((input1.i + 1)..rounded1.i).collect(); + let rounded_layer: Vec<_> = ((input1.0 + 1)..rounded1.0).collect(); let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); let partitions = partitionning(&dag, false); let consider = |op_i: usize| &partitions.instrs_partition[op_i]; @@ -549,7 +545,7 @@ pub mod tests { assert!(p.instruction_partition == HIGH_PRECISION_PARTITION); } // input is converted with a fast keyswitch to LOW_PRECISION_PARTITION - let p = consider(input1.i); + let p = consider(input1.0); assert!(p.alternative_output_representation == HashSet::from([LOW_PRECISION_PARTITION])); let read_converted = Some(Transition::Additional { src_partition: HIGH_PRECISION_PARTITION, @@ -579,7 +575,7 @@ pub mod tests { assert!(first_bit_erase_verified); // Second layer, lut part is HIGH_PRECISION_PARTITION // and use an internal conversion - let p = consider(lut1.i); + let p = consider(lut1.0); assert!(p.instruction_partition == HIGH_PRECISION_PARTITION); assert!( p.inputs_transition[0] @@ -593,12 +589,12 @@ pub mod tests { fn test_rounded_v1_classic_first_layer_second_layer() { let acc_precision = 8; let precision = 6; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let free_input1 = dag.add_input(precision, Shape::number()); let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); - let first_layer = free_input1.i..=input1.i; + let first_layer = free_input1.0..=input1.0; let rounded1 = dag.add_expanded_round(input1, precision); - let rounded_layer = (input1.i + 1)..rounded1.i; + let rounded_layer = (input1.0 + 1)..rounded1.0; let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); let partitions = partitionning_with_preferred(&dag, &default_p_cut(), HIGH_PRECISION_PARTITION, false); @@ -610,7 +606,7 @@ pub mod tests { assert!(consider(op_i).instruction_partition == HIGH_PRECISION_PARTITION); } // input is converted with a fast keyswitch to LOW_PRECISION_PARTITION - assert!(consider(input1.i) + assert!(consider(input1.0) .alternative_output_representation .is_empty()); let read_converted = Some(Transition::Additional { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs index 1a01b0ae22..f3fdc8717c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs @@ -189,10 +189,10 @@ impl SymbolicVariance { impl fmt::Display for SymbolicVariance { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if self == &Self::ZERO { - write!(f, "ZERO x σ²")?; + return write!(f, "ZERO x σ²"); } if self.coeffs[0].is_nan() { - write!(f, "NAN x σ²")?; + return write!(f, "NAN x σ²"); } let mut add_plus = ""; for src_partition in 0..self.nb_partitions() { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/visualization.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/visualization.rs deleted file mode 100644 index a525d2fec1..0000000000 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/visualization.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::iter::once; -use std::path::PathBuf; -use std::process::Command; - -use super::analyze::AnalyzedDag; -use super::partitions::Transition; -use crate::dag::operator::Operator; - -const COLORSCHEME: &str = "set19"; - -#[allow(unused)] -/// Saves a graphviz representation of the analyzed dag. -/// -/// The image is stored in your temp folder as `concrete_optimizer_dbg.png`. -pub fn save_dag_dot(dag: &AnalyzedDag) { - let path = write_dot_svg(&analyzed_dag_to_dot_string(dag), None); - println!("Analyzed dag visible at path {}", path.to_str().unwrap()); -} - -#[allow(unused)] -/// Dump the dag dot code and panic. -/// -/// For debug purpose essentially. -pub fn dump_dag_dot(dag: &AnalyzedDag) { - println!("{}", analyzed_dag_to_dot_string(dag)); - panic!(); -} - -fn write_dot_svg(dot: &str, maybe_path: Option) -> PathBuf { - let mut path = maybe_path.unwrap_or_else(std::env::temp_dir); - path.push("concrete_optimizer_dbg.png"); - let _ = Command::new("sh") - .arg("-c") - .arg(format!( - "echo '{}' | dot -Tpng > {}", - dot, - path.to_str().unwrap() - )) - .output() - .expect("Failed to execute dot. Do you have graphviz installed ?"); - path -} - -fn extract_node_inputs(node: &Operator) -> Vec { - node.get_inputs_iter().map(|id| id.i).collect() -} - -fn extract_node_label(node: &Operator, index: usize) -> String { - let input_string = extract_node_inputs(node) - .iter() - .map(|n| format!("%{n}")) - .collect::>() - .join(", "); - match node { - Operator::Input { out_precision, .. } => { - format!("{{%{index} = Input({input_string}) |{{out_precision:|{out_precision:?}}}}}",) - } - Operator::Lut { out_precision, .. } => { - format!("{{%{index} = Lut({input_string}) |{{out_precision:|{out_precision:?}}}}}",) - } - Operator::Dot { .. } => { - format!("{{%{index} = Dot({input_string})}}") - } - Operator::LevelledOp { manp, .. } => { - format!("{{%{index} = LevelledOp({input_string}) |{{manp:|{manp:?}}}}}",) - } - Operator::UnsafeCast { out_precision, .. } => format!( - "{{%{index} = UnsafeCast({input_string}) |{{out_precision:|{out_precision:?}}}}}", - ), - Operator::Round { out_precision, .. } => { - format!("{{%{index} = Round({input_string}) |{{out_precision:|{out_precision:?}}}}}",) - } - } -} - -fn analyzed_dag_to_dot_string(dag: &AnalyzedDag) -> String { - let partitions: Vec = dag - .p_cut - .p_cut - .iter() - .map(|(p, _)| format!("{p}")) - .chain(once(String::new())) - .enumerate() - .map(|(i, pci)| { - format!( - "partition_{i} [label=\"{{ Partition {i} | {{ p_cut: | {pci} }} }}\" fillcolor={}]", - i + 1 - ) - }) - .collect(); - - let mut graph: Vec = vec![]; - let iterator = dag - .operators - .iter() - .zip(dag.instrs_partition.iter()) - .enumerate(); - - for (i, (node, partition)) in iterator { - if let Operator::Lut { input, .. } = node { - let input_partition_color = dag.instrs_partition[input.i].instruction_partition + 1; - let lut_partition_color = partition.instruction_partition + 1; - let input_index = input.i; - let label = extract_node_label(node, i); - let node = format!(" - {input_index} -> ks_{i} [color=\"/{COLORSCHEME}/{input_partition_color}\"]; - subgraph cluster_{i}{{ - label_{i} [label=\"{label}\"]; - ks_{i} [label=\"KS\" fillcolor=\"/{COLORSCHEME}/{input_partition_color}:/{COLORSCHEME}/{lut_partition_color}\"]; - {i} [label=\"MS\\+BR\" fillcolor={lut_partition_color}]; - ks_{i} -> {i} [color=\"/{COLORSCHEME}/{lut_partition_color}\"]; - }} - "); - graph.push(node); - for fks_i in &partition.alternative_output_representation { - let output_partition_color = fks_i + 1; - graph.push(format!(" - fks_{i}_{fks_i} [label=\"FKS\", fillcolor=\"/{COLORSCHEME}/{lut_partition_color}:/{COLORSCHEME}/{output_partition_color}\"]; - {i} -> fks_{i}_{fks_i} [color=\"/{COLORSCHEME}/{lut_partition_color}\"]; - ")); - } - } else { - let partition_index = partition.instruction_partition; - let partition_color = partition_index + 1; - let label = extract_node_label(node, i); - graph.push(format!( - "{i} [label=\"{label}\" fillcolor={partition_color}];", - )); - - for (j, input_index) in extract_node_inputs(node).into_iter().enumerate() { - match partition.inputs_transition.get(j) { - Some(Some(Transition::Additional { .. })) => { - graph.push(format!("fks_{input_index}_{partition_index} -> {i} [color=\"/{COLORSCHEME}/{partition_color}\"]")); - } - _ => { - graph.push(format!( - "{input_index} -> {i} [color=\"/{COLORSCHEME}/{partition_color}\"]", - )); - } - }; - } - } - } - - format!( - " -digraph G {{ -fontname=\"Helvetica,Arial,sans-serif\" -node [fontname=\"Helvetica,Arial,sans-serif\" gradientangle=270] -edge [fontname=\"Helvetica,Arial,sans-serif\"] -rankdir=TB; -node [shape=record colorscheme={COLORSCHEME} style=filled]; -subgraph cluster_0 {{ -label=\"Partitions\" -{} -}} -{} -}}", - partitions.join("\n"), - graph.join("\n"), - ) -} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 5cd97b0169..48b5d44a24 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -1,21 +1,18 @@ use super::symbolic_variance::{SymbolicVariance, VarianceOrigin}; use crate::dag::operator::{ - dot_kind, DotKind, LevelledComplexity, OperatorIndex, Precision, Shape, + dot_kind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, }; use crate::dag::rewrite::round::expand_round; -use crate::dag::unparametrized; +use crate::dag::unparametrized::Dag; use crate::noise_estimator::error; use crate::noise_estimator::p_error::{combine_errors, repeat_p_error}; use crate::optimization::config::NoiseBoundConfig; use crate::utils::square; +use dot_kind::DotKind; use std::collections::{HashMap, HashSet}; -// private short convention -use {DotKind as DK, VarianceOrigin as VO}; -type Op = unparametrized::UnparameterizedOperator; - pub fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property { - &properties[inputs[0].i] + &properties[inputs[0].0] } fn assert_all_same( @@ -24,48 +21,42 @@ fn assert_all_same( ) { let first = first(inputs, properties); for input in inputs.iter().skip(1) { - assert_eq!(first, &properties[input.i]); + assert_eq!(first, &properties[input.0]); } } -fn assert_inputs_uniform_precisions( - op: &unparametrized::UnparameterizedOperator, - out_precisions: &[Precision], -) { - if let Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } = op { +fn assert_inputs_uniform_precisions(op: &Operator, out_precisions: &[Precision]) { + if let Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } = op { assert_all_same(inputs, out_precisions); } } -fn assert_dot_uniform_inputs_shape( - op: &unparametrized::UnparameterizedOperator, - out_shapes: &[Shape], -) { - if let Op::Dot { inputs, .. } = op { +fn assert_dot_uniform_inputs_shape(op: &Operator, out_shapes: &[Shape]) { + if let Operator::Dot { inputs, .. } = op { assert_all_same(inputs, out_shapes); } } -fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) { - if let Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } = op { +fn assert_non_empty_inputs(op: &Operator) { + if let Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } = op { assert!(!inputs.is_empty()); } } -fn assert_inputs_index(op: &unparametrized::UnparameterizedOperator, first_bad_index: usize) { +fn assert_inputs_index(op: &Operator, first_bad_index: usize) { let valid = match op { - Op::Input { .. } => true, - Op::Lut { input, .. } | Op::UnsafeCast { input, .. } | Op::Round { input, .. } => { - input.i < first_bad_index - } - Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => { - inputs.iter().all(|input| input.i < first_bad_index) + Operator::Input { .. } => true, + Operator::Lut { input, .. } + | Operator::UnsafeCast { input, .. } + | Operator::Round { input, .. } => input.0 < first_bad_index, + Operator::LevelledOp { inputs, .. } | Operator::Dot { inputs, .. } => { + inputs.iter().all(|input| input.0 < first_bad_index) } }; assert!(valid, "Invalid dag, bad index in op: {op:?}"); } -fn assert_dag_correctness(dag: &unparametrized::OperationDag) { +fn assert_dag_correctness(dag: &Dag) { for (i, op) in dag.operators.iter().enumerate() { assert_non_empty_inputs(op); assert_inputs_uniform_precisions(op, &dag.out_precisions); @@ -74,29 +65,29 @@ fn assert_dag_correctness(dag: &unparametrized::OperationDag) { } } -pub fn has_round(dag: &unparametrized::OperationDag) -> bool { +pub fn has_round(dag: &Dag) -> bool { for op in &dag.operators { - if matches!(op, Op::Round { .. }) { + if matches!(op, Operator::Round { .. }) { return true; } } false } -pub fn has_unsafe_cast(dag: &unparametrized::OperationDag) -> bool { +pub fn has_unsafe_cast(dag: &Dag) -> bool { for op in &dag.operators { - if matches!(op, Op::UnsafeCast { .. }) { + if matches!(op, Operator::UnsafeCast { .. }) { return true; } } false } -pub fn assert_no_round(dag: &unparametrized::OperationDag) { +pub fn assert_no_round(dag: &Dag) { assert!(!has_round(dag)); } -fn assert_valid_variances(dag: &OperationDag) { +fn assert_valid_variances(dag: &SoloKeyDag) { for &out_variance in &dag.out_variances { assert!( SymbolicVariance::ZERO == out_variance // Special case of multiply by 0 @@ -106,24 +97,24 @@ fn assert_valid_variances(dag: &OperationDag) { } } -fn assert_properties_correctness(dag: &OperationDag) { +fn assert_properties_correctness(dag: &SoloKeyDag) { assert_valid_variances(dag); } fn variance_origin(inputs: &[OperatorIndex], out_variances: &[SymbolicVariance]) -> VarianceOrigin { let first_origin = first(inputs, out_variances).origin(); for input in inputs.iter().skip(1) { - let item = &out_variances[input.i]; + let item = &out_variances[input.0]; if first_origin != item.origin() { - return VO::Mixed; + return VarianceOrigin::Mixed; } } first_origin } #[derive(Clone, Debug)] -pub struct OperationDag { - pub operators: Vec, +pub struct SoloKeyDag { + pub operators: Vec, // Collect all operators output variances pub out_variances: Vec, pub nb_luts: u64, @@ -148,55 +139,55 @@ pub struct VariancesAndBound { } fn out_variance( - op: &unparametrized::UnparameterizedOperator, + op: &Operator, out_shapes: &[Shape], out_variances: &[SymbolicVariance], ) -> SymbolicVariance { // Maintain a linear combination of input_variance and lut_out_variance // TODO: track each elements instead of container match op { - Op::Input { .. } => SymbolicVariance::INPUT, - Op::Lut { .. } => SymbolicVariance::LUT, - Op::LevelledOp { inputs, manp, .. } => { + Operator::Input { .. } => SymbolicVariance::INPUT, + Operator::Lut { .. } => SymbolicVariance::LUT, + Operator::LevelledOp { inputs, manp, .. } => { let variance_factor = SymbolicVariance::manp_to_variance_factor(*manp); let origin = match variance_origin(inputs, out_variances) { - VO::Input => SymbolicVariance::INPUT, - VO::Lut | VO::Mixed /* Mixed: assume the worst */ + VarianceOrigin::Input => SymbolicVariance::INPUT, + VarianceOrigin::Lut | VarianceOrigin::Mixed /* Mixed: assume the worst */ => SymbolicVariance::LUT }; origin * variance_factor } - Op::Dot { + Operator::Dot { inputs, weights, .. } => { let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor | DK::Broadcast { .. } => { + DotKind::Simple | DotKind::Tensor | DotKind::Broadcast { .. } => { let first_input = inputs[0]; let mut out_variance = SymbolicVariance::ZERO; for (j, &weight) in weights.values.iter().enumerate() { let k = if inputs.len() > 1 { - inputs[j].i + inputs[j].0 } else { - first_input.i + first_input.0 }; out_variance += out_variances[k] * square(weight as f64); } out_variance } - DK::CompatibleTensor { .. } => todo!("TODO"), - DK::Unsupported { .. } => panic!("Unsupported"), + DotKind::CompatibleTensor { .. } => todo!("TODO"), + DotKind::Unsupported { .. } => panic!("Unsupported"), } } - Op::UnsafeCast { input, .. } => out_variances[input.i], - Op::Round { .. } => { + Operator::UnsafeCast { input, .. } => out_variances[input.0], + Operator::Round { .. } => { unreachable!("Round should have been either expanded or integrated to a lut") } } } -pub fn out_variances(dag: &unparametrized::OperationDag) -> Vec { +pub fn out_variances(dag: &Dag) -> Vec { let nb_ops = dag.operators.len(); let mut out_variances = Vec::with_capacity(nb_ops); for op in &dag.operators { @@ -207,33 +198,27 @@ pub fn out_variances(dag: &unparametrized::OperationDag) -> Vec Vec<(Precision, Shape, SymbolicVariance)> { - dag.get_output_index_iter() - .map(|i| { - ( - dag.out_precisions[i], - dag.out_shapes[i].clone(), - out_variances[i], - ) - }) + dag.get_output_operators_iter() + .map(|op| (*op.precision, op.shape.to_owned(), out_variances[op.id.0])) .collect() } fn in_luts_variance( - dag: &unparametrized::OperationDag, + dag: &Dag, out_variances: &[SymbolicVariance], ) -> Vec<(Precision, Shape, SymbolicVariance)> { dag.operators .iter() .enumerate() .filter_map(|(i, op)| { - if let &Op::Lut { input, .. } = op { + if let &Operator::Lut { input, .. } = op { Some(( - dag.out_precisions[input.i], + dag.out_precisions[input.0], dag.out_shapes[i].clone(), - out_variances[input.i], + out_variances[input.0], )) } else { None @@ -242,32 +227,34 @@ fn in_luts_variance( .collect() } -fn op_levelled_complexity( - op: &unparametrized::UnparameterizedOperator, - out_shapes: &[Shape], -) -> LevelledComplexity { +fn op_levelled_complexity(op: &Operator, out_shapes: &[Shape]) -> LevelledComplexity { match op { - Op::Dot { + Operator::Dot { inputs, weights, .. } => { let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor | DK::Broadcast { .. } | DK::CompatibleTensor => { + DotKind::Simple + | DotKind::Tensor + | DotKind::Broadcast { .. } + | DotKind::CompatibleTensor => { LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size() } - DK::Unsupported { .. } => panic!("Unsupported"), + DotKind::Unsupported { .. } => panic!("Unsupported"), } } - Op::LevelledOp { complexity, .. } => *complexity, - Op::Input { .. } | Op::Lut { .. } | Op::UnsafeCast { .. } => LevelledComplexity::ZERO, - Op::Round { .. } => { + Operator::LevelledOp { complexity, .. } => *complexity, + Operator::Input { .. } | Operator::Lut { .. } | Operator::UnsafeCast { .. } => { + LevelledComplexity::ZERO + } + Operator::Round { .. } => { unreachable!("Round should have been either expanded or integrated to a lut") } } } -pub fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComplexity { +pub fn levelled_complexity(dag: &Dag) -> LevelledComplexity { let mut levelled_complexity = LevelledComplexity::ZERO; for op in &dag.operators { levelled_complexity += op_levelled_complexity(op, &dag.out_shapes); @@ -275,12 +262,12 @@ pub fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComple levelled_complexity } -pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 { +pub fn lut_count_from_dag(dag: &Dag) -> u64 { let mut count = 0; for (i, op) in dag.operators.iter().enumerate() { - if let Op::Lut { .. } = op { + if let Operator::Lut { .. } = op { count += dag.out_shapes[i].flat_size(); - } else if let Op::Round { out_precision, .. } = op { + } else if let Operator::Round { out_precision, .. } = op { count += dag.out_shapes[i].flat_size() * (dag.out_precisions[i] - out_precision) as u64; } } @@ -378,7 +365,7 @@ fn constraint_for_one_precision( } } -pub fn worst_log_norm_for_wop(dag: &unparametrized::OperationDag) -> f64 { +pub fn worst_log_norm_for_wop(dag: &Dag) -> f64 { assert_dag_correctness(dag); assert_no_round(dag); let out_variances = out_variances(dag); @@ -393,10 +380,7 @@ pub fn worst_log_norm_for_wop(dag: &unparametrized::OperationDag) -> f64 { worst.log2() } -pub fn analyze( - dag: &unparametrized::OperationDag, - noise_config: &NoiseBoundConfig, -) -> OperationDag { +pub fn analyze(dag: &Dag, noise_config: &NoiseBoundConfig) -> SoloKeyDag { assert_dag_correctness(dag); let dag = &expand_round(dag); assert_no_round(dag); @@ -411,7 +395,7 @@ pub fn analyze( &in_luts_variance, noise_config, ); - let result = OperationDag { + let result = SoloKeyDag { operators: dag.operators.clone(), out_variances, nb_luts, @@ -463,7 +447,7 @@ fn peak_variance_per_constraint( // Compute the maximum attained relative variance for the full dag fn peak_relative_variance( - dag: &OperationDag, + dag: &SoloKeyDag, input_noise_out: f64, blind_rotate_noise_out: f64, noise_keyswitch: f64, @@ -526,7 +510,7 @@ fn p_error_per_constraint( p_error } -impl OperationDag { +impl SoloKeyDag { pub fn peek_p_error( &self, input_noise_out: f64, @@ -610,14 +594,13 @@ pub mod tests { use super::*; use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape, Weights}; - use crate::dag::unparametrized; use crate::utils::square; fn assert_f64_eq(v: f64, expected: f64) { approx::assert_relative_eq!(v, expected, epsilon = f64::EPSILON); } - impl OperationDag { + impl SoloKeyDag { pub fn constraint(&self) -> VariancesAndBound { assert!(!self.constraints_by_precisions.is_empty()); assert_eq!(self.constraints_by_precisions.len(), 1); @@ -633,24 +616,23 @@ pub mod tests { maximum_acceptable_error_probability: _4_SIGMA, }; - fn analyze(dag: &unparametrized::OperationDag) -> super::OperationDag { + fn analyze(dag: &Dag) -> super::SoloKeyDag { super::analyze(dag, &CONFIG) } #[test] fn test_1_input() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let input1 = graph.add_input(1, Shape::number()); - graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); - assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT); - assert_eq!(graph.out_shapes[input1.i], Shape::number()); + assert_eq!(analysis.out_variances[input1.0], SymbolicVariance::INPUT); + assert_eq!(graph.out_shapes[input1.0], Shape::number()); assert_eq!(analysis.levelled_complexity, LevelledComplexity::ZERO); - assert_eq!(graph.out_precisions[input1.i], 1); + assert_eq!(graph.out_precisions[input1.0], 1); assert_f64_eq(complexity_cost, 0.0); assert!(analysis.nb_luts == 0); let constraint = analysis.constraint(); @@ -662,19 +644,18 @@ pub mod tests { #[test] fn test_1_lut() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let input1 = graph.add_input(8, Shape::number()); let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 8); - graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); - assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT); - assert!(graph.out_shapes[lut1.i] == Shape::number()); + assert!(analysis.out_variances[lut1.0] == SymbolicVariance::LUT); + assert!(graph.out_shapes[lut1.0] == Shape::number()); assert!(analysis.levelled_complexity == LevelledComplexity::ZERO); - assert_eq!(graph.out_precisions[lut1.i], 8); + assert_eq!(graph.out_precisions[lut1.0], 8); assert_f64_eq(one_lut_cost, complexity_cost); let constraint = analysis.constraint(); assert!(constraint.pareto_output.len() == 1); @@ -687,12 +668,11 @@ pub mod tests { #[test] fn test_1_dot() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let input1 = graph.add_input(1, Shape::number()); let weights = Weights::vector([1, 2]); let norm2: f64 = 1.0 * 1.0 + 2.0 * 2.0; let dot = graph.add_dot([input1, input1], weights); - graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -702,10 +682,10 @@ pub mod tests { input_coeff: norm2, lut_coeff: 0.0, }; - assert!(analysis.out_variances[dot.i] == expected_var); - assert!(graph.out_shapes[dot.i] == Shape::number()); + assert!(analysis.out_variances[dot.0] == expected_var); + assert!(graph.out_shapes[dot.0] == Shape::number()); assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 2); - assert_eq!(graph.out_precisions[dot.i], 1); + assert_eq!(graph.out_precisions[dot.0], 1); let expected_dot_cost = (2 * lwe_dim) as f64; assert_f64_eq(expected_dot_cost, complexity_cost); let constraint = analysis.constraint(); @@ -717,23 +697,22 @@ pub mod tests { #[test] fn test_1_dot_levelled() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let input1 = graph.add_input(3, Shape::number()); let cpx_dot = LevelledComplexity::ADDITION; let weights = Weights::vector([1, 2]); #[allow(clippy::imprecise_flops)] let manp = (1.0 * 1.0 + 2.0 * 2_f64).sqrt(); let dot = graph.add_levelled_op([input1, input1], cpx_dot, manp, Shape::number(), "dot"); - graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); - assert!(analysis.out_variances[dot.i].origin() == VO::Input); - assert_eq!(graph.out_precisions[dot.i], 3); + assert!(analysis.out_variances[dot.0].origin() == VarianceOrigin::Input); + assert_eq!(graph.out_precisions[dot.0], 3); let expected_square_norm2 = weights.square_norm2() as f64; - let actual_square_norm2 = analysis.out_variances[dot.i].input_coeff; + let actual_square_norm2 = analysis.out_variances[dot.0].input_coeff; // Due to call on log2() to compute manp the result is not exact assert_f64_eq(actual_square_norm2, expected_square_norm2); assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION); @@ -741,20 +720,19 @@ pub mod tests { let constraint = analysis.constraint(); assert!(constraint.pareto_in_lut.is_empty()); assert!(constraint.pareto_output.len() == 1); - assert_eq!(constraint.pareto_output[0].origin(), VO::Input); + assert_eq!(constraint.pareto_output[0].origin(), VarianceOrigin::Input); assert_f64_eq(constraint.pareto_output[0].input_coeff, 5.0); } #[test] fn test_dot_tensorized_lut_dot_lut() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let input1 = graph.add_input(1, Shape::vector(2)); let weights = &Weights::vector([1, 2]); let dot1 = graph.add_dot([input1], weights); let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN, 1); let dot2 = graph.add_dot([lut1, lut1], weights); let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN, 1); - graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -776,19 +754,19 @@ pub mod tests { input_coeff: 0.0, lut_coeff: 1.0, }; - assert!(analysis.out_variances[dot1.i] == expected_var_dot1); - assert!(analysis.out_variances[lut1.i] == expected_var_lut1); - assert!(analysis.out_variances[dot2.i] == expected_var_dot2); - assert!(analysis.out_variances[lut2.i] == expected_var_lut2); + assert!(analysis.out_variances[dot1.0] == expected_var_dot1); + assert!(analysis.out_variances[lut1.0] == expected_var_lut1); + assert!(analysis.out_variances[dot2.0] == expected_var_dot2); + assert!(analysis.out_variances[lut2.0] == expected_var_lut2); assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 4); let expected_cost = (lwe_dim * 4) as f64 + 2.0 * one_lut_cost; assert_f64_eq(expected_cost, complexity_cost); let constraint = analysis.constraint(); assert_eq!(constraint.pareto_output.len(), 1); - assert_eq!(constraint.pareto_output[0].origin(), VO::Lut); + assert_eq!(constraint.pareto_output[0].origin(), VarianceOrigin::Lut); assert_f64_eq(constraint.pareto_output[0].lut_coeff, 1.0); assert_eq!(constraint.pareto_in_lut.len(), 1); - assert_eq!(constraint.pareto_in_lut[0].origin(), VO::Lut); + assert_eq!(constraint.pareto_in_lut[0].origin(), VarianceOrigin::Lut); assert_f64_eq( constraint.pareto_in_lut[0].lut_coeff, weights.square_norm2() as f64, @@ -797,13 +775,12 @@ pub mod tests { #[test] fn test_lut_dot_mixed_lut() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let input1 = graph.add_input(1, Shape::number()); let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1); let weights = &Weights::vector([2, 3]); let dot1 = graph.add_dot([input1, lut1], weights); let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN, 1); - graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -819,18 +796,17 @@ pub mod tests { assert_eq!(constraint.pareto_output.len(), 1); assert_eq!(constraint.pareto_output[0], SymbolicVariance::LUT); assert_eq!(constraint.pareto_in_lut.len(), 1); - assert_eq!(constraint.pareto_in_lut[0].origin(), VO::Mixed); + assert_eq!(constraint.pareto_in_lut[0].origin(), VarianceOrigin::Mixed); assert_eq!(constraint.pareto_in_lut[0], expected_mixed); } #[test] fn test_multi_precision_input() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let max_precision: Precision = 5; for i in 1..=max_precision { _ = graph.add_input(i, Shape::number()); } - graph.detect_outputs(); let analysis = analyze(&graph); assert!(analysis.constraints_by_precisions.len() == max_precision as usize); let mut prev_safe_noise_bound = 0.0; @@ -845,13 +821,12 @@ pub mod tests { #[test] fn test_multi_precision_lut() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let max_precision: Precision = 5; for p in 1..=max_precision { let input = graph.add_input(p, Shape::number()); let _lut = graph.add_lut(input, FunctionTable::UNKWOWN, p); } - graph.detect_outputs(); let analysis = analyze(&graph); assert!(analysis.constraints_by_precisions.len() == max_precision as usize); let mut prev_safe_noise_bound = 0.0; @@ -871,7 +846,7 @@ pub mod tests { #[test] fn test_broadcast_dot_multiply_by_number() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let shape = Shape { dimensions_size: vec![2, 2], }; @@ -879,14 +854,13 @@ pub mod tests { let weights = &Weights::number(2); _ = graph.add_dot([input1], weights); assert!(*graph.out_shapes.last().unwrap() == shape); - graph.detect_outputs(); let analysis = analyze(&graph); assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0); } #[test] fn test_broadcast_dot_add_tensor() { - let mut graph = unparametrized::OperationDag::new(); + let mut graph = Dag::new(); let shape = Shape { dimensions_size: vec![2, 2], }; @@ -896,7 +870,6 @@ pub mod tests { let weights = &Weights::vector([2, 3]); _ = graph.add_dot([input1, lut2], weights); assert!(*graph.out_shapes.last().unwrap() == shape); - graph.detect_outputs(); let analysis = analyze(&graph); assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0); assert_f64_eq(analysis.out_variances.last().unwrap().lut_coeff, 9.0); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 8f92a2c284..aff178f6ef 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -4,7 +4,7 @@ use concrete_security_curves::gaussian::security::minimal_variance_lwe; use super::analyze; use crate::dag::operator::{LevelledComplexity, Precision}; use crate::dag::unparametrized; -use crate::dag::unparametrized::OperationDag; +use crate::dag::unparametrized::Dag; use crate::noise_estimator::error; use crate::optimization::atomic_pattern::{ OptimizationDecompositionsConsts, OptimizationState, Solution, @@ -23,7 +23,7 @@ use crate::parameters::GlweParameters; fn update_best_solution_with_best_decompositions( state: &mut OptimizationState, consts: &OptimizationDecompositionsConsts, - dag: &analyze::OperationDag, + dag: &analyze::SoloKeyDag, internal_dim: u64, glwe_params: GlweParameters, input_noise_out: f64, @@ -137,7 +137,7 @@ const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8; fn update_no_luts_solution( state: &mut OptimizationState, consts: &OptimizationDecompositionsConsts, - dag: &analyze::OperationDag, + dag: &analyze::SoloKeyDag, input_lwe_dimension: u64, input_noise_out: f64, ) { @@ -203,7 +203,7 @@ fn minimal_variance(config: &Config, glwe_params: GlweParameters) -> f64 { fn optimize_no_luts( mut state: OptimizationState, consts: &OptimizationDecompositionsConsts, - dag: &analyze::OperationDag, + dag: &analyze::SoloKeyDag, search_space: &SearchSpace, ) -> OptimizationState { let not_feasible = |input_noise_out| !dag.feasible(input_noise_out, 0.0, 0.0, 0.0); @@ -221,7 +221,7 @@ fn optimize_no_luts( } fn not_feasible_macro_parameters( - dag: &analyze::OperationDag, + dag: &analyze::SoloKeyDag, internal_dim: u64, input_noise_out: f64, noise_modulus_switching: f64, @@ -240,7 +240,7 @@ fn not_feasible_macro_parameters( fn too_complex_macro_parameters( state: &OptimizationState, - dag: &analyze::OperationDag, + dag: &analyze::SoloKeyDag, internal_dim: u64, glwe_params: GlweParameters, cmux_pareto: &[CmuxComplexityNoise], @@ -261,7 +261,7 @@ fn too_complex_macro_parameters( #[allow(clippy::too_many_lines)] pub fn optimize( - dag: &unparametrized::OperationDag, + dag: &unparametrized::Dag, config: Config, search_space: &SearchSpace, persistent_caches: &PersistDecompCaches, @@ -383,7 +383,7 @@ pub fn optimize( state } -pub fn add_v0_dag(dag: &mut OperationDag, sum_size: u64, precision: u64, noise_factor: f64) { +pub fn add_v0_dag(dag: &mut Dag, sum_size: u64, precision: u64, noise_factor: f64) { use crate::dag::operator::{FunctionTable, Shape}; let same_scale_manp = 1.0; let manp = noise_factor; @@ -398,8 +398,8 @@ pub fn add_v0_dag(dag: &mut OperationDag, sum_size: u64, precision: u64, noise_f let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); } -pub fn v0_dag(sum_size: u64, precision: u64, noise_factor: f64) -> OperationDag { - let mut dag = unparametrized::OperationDag::new(); +pub fn v0_dag(sum_size: u64, precision: u64, noise_factor: f64) -> Dag { + let mut dag = unparametrized::Dag::new(); add_v0_dag(&mut dag, sum_size, precision, noise_factor); dag } @@ -472,7 +472,7 @@ pub(crate) mod tests { ) }); - pub fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState { + pub fn optimize(dag: &unparametrized::Dag) -> OptimizationState { let config = Config { security_level: 128, maximum_acceptable_error_probability: _4_SIGMA, @@ -589,14 +589,13 @@ pub(crate) mod tests { let processing_unit = config::ProcessingUnit::Cpu; let security_level = 128; - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); { let input1 = dag.add_input(precision, Shape::number()); let dot1 = dag.add_dot([input1], [1]); let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); let dot2 = dag.add_dot([lut1], [weight]); let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); - dag.detect_outputs(); } { let dag2 = analyze::analyze( @@ -655,11 +654,11 @@ pub(crate) mod tests { } fn no_lut_vs_lut(precision: Precision) { - let mut dag_lut = unparametrized::OperationDag::new(); + let mut dag_lut = unparametrized::Dag::new(); let input1 = dag_lut.add_input(precision, Shape::number()); let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision); - let mut dag_no_lut = unparametrized::OperationDag::new(); + let mut dag_no_lut = unparametrized::Dag::new(); let _input2 = dag_no_lut.add_input(precision, Shape::number()); let state_no_lut = optimize(&dag_no_lut); @@ -691,7 +690,7 @@ pub(crate) mod tests { ) { let weight = &Weights::number(weight); - let mut dag_1 = unparametrized::OperationDag::new(); + let mut dag_1 = unparametrized::Dag::new(); { let input1 = dag_1.add_input(precision, Shape::number()); let scaled_input1 = dag_1.add_dot([input1], weight); @@ -699,7 +698,7 @@ pub(crate) mod tests { let _lut2 = dag_1.add_lut(lut1, FunctionTable::UNKWOWN, precision); } - let mut dag_2 = unparametrized::OperationDag::new(); + let mut dag_2 = unparametrized::Dag::new(); { let input1 = dag_2.add_input(precision, Shape::number()); let lut1 = dag_2.add_lut(input1, FunctionTable::UNKWOWN, precision); @@ -731,14 +730,14 @@ pub(crate) mod tests { fn lut_1_layer_has_better_complexity(precision: Precision) { let dag_1_layer = { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(precision, Shape::number()); let _lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision); let _lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision); dag }; let dag_2_layer = { - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(precision, Shape::number()); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision); let _lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision); @@ -765,7 +764,7 @@ pub(crate) mod tests { } } - fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: i64) { + fn circuit(dag: &mut unparametrized::Dag, precision: Precision, weight: i64) { let input = dag.add_input(precision, Shape::number()); let dot1 = dag.add_dot([input], [weight]); let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); @@ -776,9 +775,9 @@ pub(crate) mod tests { fn assert_multi_precision_dominate_single(weight: i64) -> Option { let low_precision = 4u8; let high_precision = 5u8; - let mut dag_low = unparametrized::OperationDag::new(); - let mut dag_high = unparametrized::OperationDag::new(); - let mut dag_multi = unparametrized::OperationDag::new(); + let mut dag_low = unparametrized::Dag::new(); + let mut dag_high = unparametrized::Dag::new(); + let mut dag_multi = unparametrized::Dag::new(); { circuit(&mut dag_low, low_precision, weight); @@ -855,7 +854,7 @@ pub(crate) mod tests { fn check_global_p_error_input(dim: u64, weight: i64, precision: u8) -> f64 { let shape = Shape::vector(dim); let weights = Weights::number(weight); - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(precision, shape); let _dot1 = dag.add_dot([input1], weights); // this is just several multiply let state = optimize(&dag); @@ -879,7 +878,7 @@ pub(crate) mod tests { fn check_global_p_error_lut(depth: u64, weight: i64, precision: u8) { let shape = Shape::number(); let weights = Weights::number(weight); - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let mut last_val = dag.add_input(precision, shape); for _i in 0..depth { let dot = dag.add_dot([last_val], &weights); @@ -903,9 +902,9 @@ pub(crate) mod tests { precision_high: Precision, weight_low: i64, weight_high: i64, - ) -> unparametrized::OperationDag { + ) -> unparametrized::Dag { let shape = Shape::number(); - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let weights_low = Weights::number(weight_low); let weights_high = Weights::number(weight_high); let mut last_val_low = dag.add_input(precision_low, &shape); @@ -977,10 +976,10 @@ pub(crate) mod tests { rounded_precision: Precision, precision: Precision, weight: i64, - ) -> unparametrized::OperationDag { + ) -> unparametrized::Dag { // circuit with intermediate high precision in levelled op let shape = Shape::number(); - let mut dag = unparametrized::OperationDag::new(); + let mut dag = unparametrized::Dag::new(); let weight = Weights::number(weight); let val = dag.add_input(precision, shape); let lut1 = dag.add_rounded_lut(val, FunctionTable::UNKWOWN, rounded_precision, precision); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs index 283656b955..36d08681cf 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs @@ -1,5 +1,5 @@ use crate::dag::operator::Precision; -use crate::dag::unparametrized::OperationDag; +use crate::dag::unparametrized::Dag; use crate::noise_estimator::p_error::repeat_p_error; use crate::optimization::atomic_pattern::Solution as WpSolution; use crate::optimization::config::{Config, SearchSpace}; @@ -31,7 +31,7 @@ pub enum Encoding { Crt, } -pub fn max_precision(dag: &OperationDag) -> Precision { +pub fn max_precision(dag: &Dag) -> Precision { dag.out_precisions.iter().copied().max().unwrap_or(0) } @@ -63,7 +63,7 @@ fn best_complexity_solution(native: Option, crt: Option) -> } fn optimize_with_wop_pbs( - dag: &OperationDag, + dag: &Dag, config: Config, search_space: &SearchSpace, default_log_norm2_woppbs: f64, @@ -82,7 +82,7 @@ fn optimize_with_wop_pbs( } pub fn optimize( - dag: &OperationDag, + dag: &Dag, config: Config, search_space: &SearchSpace, encoding: Encoding, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/mod.rs index 6d8ec019b5..b84b89b725 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/utils/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/mod.rs @@ -1,6 +1,7 @@ pub mod cache; pub mod f64; pub mod hasher_builder; +pub mod viz; pub fn square(v: V) -> V where diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs new file mode 100644 index 0000000000..335ef3690c --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs @@ -0,0 +1,174 @@ +use crate::dag::operator::Operator; + +/// A trait allowing to visualize objects as graphviz/dot graphs. +/// +/// Useful to use along with the [`viz`] and [`vizp`] macros to debug objects. +pub trait Viz { + /// This method must return a string on the `dot` format, containing the description of the + /// object. + fn viz_node(&self) -> String; + + /// This method can be re-implemented if the object must be referenced by other objects `Viz` + /// implementation (to create edges mostly). + fn viz_label(&self) -> String { + String::new() + } + + /// Returns a string containing a valid dot representation of the object. + fn viz_string(&self) -> String { + format!( + " +strict digraph G {{ +fontname=\"Helvetica,Arial,sans-serif\" +node [fontname=\"Helvetica,Arial,sans-serif\" gradientangle=270] +edge [fontname=\"Helvetica,Arial,sans-serif\"] +rankdir=TB; +node [shape=record style=filled]; +{} +}}", + self.viz_node() + ) + } +} + +impl Viz for crate::dag::unparametrized::Dag { + fn viz_node(&self) -> String { + let mut graph = vec![]; + self.get_circuits_iter() + .for_each(|circuit| graph.push(circuit.viz_node())); + graph.join("\n") + } +} + +impl<'dag> Viz for crate::dag::unparametrized::DagCircuit<'dag> { + fn viz_node(&self) -> String { + let mut graph: Vec = vec![]; + let circuit = &self.circuit; + self.get_operators_iter().for_each(|node| { + graph.push(node.viz_node()); + node.get_inputs_iter().for_each(|inp_node| { + let inp_label = inp_node.viz_label(); + let oup_label = node.viz_label(); + graph.push(format!("{inp_label} -> {oup_label} [weight=10];")); + }); + }); + format!( + " +subgraph cluster_circuit_{circuit} {{ +label=\"Circuit {circuit}\" +style=\"rounded\" +{} +}} +", + graph.join("\n") + ) + } +} + +impl<'dag> Viz for crate::dag::unparametrized::DagOperator<'dag> { + fn viz_node(&self) -> String { + let input_string = self + .operator + .get_inputs_iter() + .map(|id| id.0) + .map(|n| format!("%{n}")) + .collect::>() + .join(", "); + let index = self.id; + let color = if self.is_input() || self.is_output() { + "lightseagreen" + } else { + "lightgreen" + }; + match self.operator { + Operator::Input { out_precision, .. } => { + format!("{index} [label =\"{{%{index} = Input({input_string}) |{{out_precision:|{out_precision:?}}}}}\" fillcolor={color}];") + } + Operator::Lut { out_precision, .. } => { + format!("{index} [label = \"{{%{index} = Lut({input_string}) |{{out_precision:|{out_precision:?}}}}}\" fillcolor={color}];") + } + Operator::Dot { .. } => { + format!("{index} [label = \"{{%{index} = Dot({input_string})}}\" fillcolor={color}];") + } + Operator::LevelledOp { manp, .. } => { + format!("{index} [label = \"{{%{index} = LevelledOp({input_string}) |{{manp:|{manp:?}}}}}\" fillcolor={color}];") + } + Operator::UnsafeCast { out_precision, .. } => format!( + "{index} [label = \"{{%{index} = UnsafeCast({input_string}) |{{out_precision:|{out_precision:?}}}}}\" fillcolor={color}];" + ), + Operator::Round { out_precision, .. } => { + format!("{index} [label = \"{{%{index} = Round({input_string}) |{{out_precision:|{out_precision:?}}}}}\" fillcolor={color}];",) + } + } + } + + fn viz_label(&self) -> String { + let index = self.id; + format!("{index}") + } +} + +macro_rules! _viz { + ($path: expr, $object:ident) => {{ + let mut path = std::env::temp_dir(); + path.push($path); + let _ = std::process::Command::new("sh") + .arg("-c") + .arg(format!( + "echo '{}' | dot -Tsvg > {}", + $crate::utils::viz::Viz::viz_string(&$object), + path.to_str().unwrap() + )) + .output() + .expect("Failed to execute dot. Do you have graphviz installed ?"); + }}; +} + +/// Dumps the visualization of an object to a given svg file. +#[allow(unused)] +macro_rules! viz { + ($path: expr, $object:ident) => { + $crate::utils::viz::_viz!($path, $object); + println!( + "Viz of {}:{} visible at {}/{}", + file!(), + line!(), + std::env::temp_dir().display(), + $path + ); + }; + ($object:ident) => { + $crate::utils::viz::viz!( + format!("concrete_optimizer_dbg_{}.svg", rand::random::()), + $object + ); + }; +} + +/// Dumps the visualization of an object to a given svg file and panics. +#[allow(unused)] +macro_rules! vizp { + ($path: expr, $object:ident) => {{ + $crate::utils::viz::_viz!($path, $object); + panic!( + "Viz of {}:{} visible at {}/{}", + file!(), + line!(), + std::env::temp_dir().display(), + $path + ); + }}; + ($object:ident) => { + $crate::utils::viz::vizp!( + format!("concrete_optimizer_dbg_{}.svg", rand::random::()), + $object + ); + }; +} + +#[allow(unused)] +pub(crate) use _viz; +#[allow(unused)] +pub(crate) use viz; +#[allow(unused)] +pub(crate) use vizp;