Skip to content

Commit

Permalink
feat(optimizer): allow circuit manipulation in optimizer dag
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed Apr 29, 2024
1 parent 5c5f573 commit ac44865
Show file tree
Hide file tree
Showing 33 changed files with 1,437 additions and 1,135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ namespace mlir {
namespace concretelang {

namespace optimizer {
using FunctionsDag = std::map<std::string, std::optional<Dag>>;

std::unique_ptr<mlir::Pass> createDagPass(optimizer::Config config,
optimizer::FunctionsDag &dags);
concrete_optimizer::Dag &dag);

} // namespace optimizer
} // namespace concretelang
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ mlir::LogicalResult materializeOptimizerPartitionFrontiers(
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);

llvm::Expected<std::map<std::string, std::optional<optimizer::Description>>>
llvm::Expected<std::optional<optimizer::Description>>
getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
optimizer::Config config,
std::function<bool(mlir::Pass *)> enablePass);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ constexpr Config DEFAULT_CONFIG = {
DEFAULT_COMPOSABLE,
};

using Dag = rust::Box<concrete_optimizer::OperationDag>;
using Dag = rust::Box<concrete_optimizer::Dag>;
using DagBuilder = rust::Box<concrete_optimizer::DagBuilder>;
using DagSolution = concrete_optimizer::dag::DagSolution;
using CircuitSolution = concrete_optimizer::dag::CircuitSolution;

Expand Down

Large diffs are not rendered by default.

48 changes: 14 additions & 34 deletions compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <err.h>
#include <fstream>
#include <iostream>
#include <iterator>
#include <llvm/Support/Debug.h>
#include <memory>
#include <mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h>
Expand Down Expand Up @@ -161,54 +162,33 @@ llvm::Expected<std::optional<optimizer::Description>>
CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
mlir::ModuleOp module = res.mlirModuleRef->get();
auto funcs = module.getOps<func::FuncOp>();
auto nFuncs = std::distance(funcs.begin(), funcs.end());
auto config = this->compilerOptions.optimizerConfig;
if (nFuncs > 1 &&
config.strategy !=
mlir::concretelang::optimizer::V0) { // Multi circuits without V0
return StreamStringError(
"Multi-circuits is only supported for V0 optimization.");
}
// If the values has been overwritten returns
if (this->overrideMaxEintPrecision.has_value() &&
this->overrideMaxMANP.has_value()) {
auto constraint = mlir::concretelang::V0FHEConstraint{
this->overrideMaxMANP.value(), this->overrideMaxEintPrecision.value()};
return optimizer::Description{constraint, std::nullopt};
}
auto config = this->compilerOptions.optimizerConfig;
auto descriptions = mlir::concretelang::pipeline::getFHEContextFromFHE(
auto description = mlir::concretelang::pipeline::getFHEContextFromFHE(
mlirContext, module, config, enablePass);
if (auto err = descriptions.takeError()) {
if (auto err = description.takeError()) {
return std::move(err);
}
if (descriptions->empty()) { // The pass has not been run
if (!description->has_value()) { // The pass has not been run
return std::nullopt;
}
if (descriptions->size() > 1 &&
config.strategy !=
mlir::concretelang::optimizer::V0) { // Multi circuits without V0
return StreamStringError(
"Multi-circuits is only supported for V0 optimization.");
}
if (descriptions->size() > 1) {
auto iter = descriptions->begin();
auto desc = std::move(iter->second);
if (!desc.has_value()) {
return StreamStringError("Expected description.");
}
if (!desc.value().dag.has_value()) {
return StreamStringError("Expected dag in description.");
}
iter++;
while (iter != descriptions->end()) {
if (!iter->second.has_value()) {
return StreamStringError("Expected description.");
}
if (!iter->second.value().dag.has_value()) {
return StreamStringError("Expected dag in description.");
}
desc->dag.value()->concat(*iter->second.value().dag.value());
iter++;
}
return std::move(desc);
}
return std::move(descriptions->begin()->second);
return description;
}

/// set the fheContext field if the v0Constraint can be computed
/// set the fheContext field if the v0Constraint can be computed
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
if (compilerOptions.v0Parameter.has_value()) {
Expand Down
28 changes: 11 additions & 17 deletions compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "llvm/Support/TargetSelect.h"

#include "concrete-optimizer.hpp"
#include "concretelang/Support/CompilationFeedback.h"
#include "concretelang/Support/V0Parameters.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
Expand All @@ -13,6 +14,7 @@
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Error.h"
#include <optional>

#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
Expand Down Expand Up @@ -89,13 +91,12 @@ addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr<Pass> pass,
}
}

llvm::Expected<std::map<std::string, std::optional<optimizer::Description>>>
llvm::Expected<std::optional<optimizer::Description>>
getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
optimizer::Config config,
std::function<bool(mlir::Pass *)> enablePass) {
std::optional<size_t> oMax2norm;
std::optional<size_t> oMaxWidth;
optimizer::FunctionsDag dags;

mlir::PassManager pm(&context);

Expand Down Expand Up @@ -127,26 +128,19 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
{/*.norm2 = */ ceilLog2(oMax2norm.value()),
/*.p = */ oMaxWidth.value()});
}
addPotentiallyNestedPass(pm, optimizer::createDagPass(config, dags),
auto dag = concrete_optimizer::dag::empty();
addPotentiallyNestedPass(pm, optimizer::createDagPass(config, *dag),
enablePass);
if (pm.run(module.getOperation()).failed()) {
return StreamStringError() << "Failed to create concrete-optimizer dag\n";
}
std::map<std::string, std::optional<optimizer::Description>> descriptions;
for (auto &entry_dag : dags) {
if (!constraint) {
descriptions.insert(
decltype(descriptions)::value_type(entry_dag.first, std::nullopt));
continue;
}
optimizer::Description description = {*constraint,
std::move(entry_dag.second)};
std::optional<optimizer::Description> opt_description{
std::move(description)};
descriptions.insert(decltype(descriptions)::value_type(
entry_dag.first, std::move(opt_description)));
std::optional<optimizer::Description> description;
if (!constraint) {
description = std::nullopt;
} else {
description = {*constraint, std::move(dag)};
}
return std::move(descriptions);
return std::move(description);
}

mlir::LogicalResult materializeOptimizerPartitionFrontiers(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s


//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i64 = arith.constant 1 : i64
Expand All @@ -11,6 +12,7 @@ func.func @add_glwe_const_int(%arg0: !TFHE.glwe<sk[1]<1,1024>>) -> !TFHE.glwe<sk
return %1: !TFHE.glwe<sk[1]<1,1024>>
}

// -----

//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s

//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i64 = arith.constant 1 : i64
Expand All @@ -11,6 +11,7 @@ func.func @mul_glwe_const_int(%arg0: !TFHE.glwe<sk[1]<1,1024>>) -> !TFHE.glwe<sk
return %1: !TFHE.glwe<sk[1]<1,1024>>
}

// -----

//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s

//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i64 = arith.constant 1 : i64
Expand All @@ -12,6 +12,8 @@ func.func @sub_const_int_glwe(%arg0: !TFHE.glwe<sk[1]<1,1024>>) -> !TFHE.glwe<sk
return %1: !TFHE.glwe<sk[1]<1,1024>>
}

// -----

//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_tensor"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64>
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_tensor"(%[[V0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes fhe-boolean-transform --action=dump-fhe %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes fhe-boolean-transform --action=dump-fhe --split-input-file %s 2>&1| FileCheck %s

// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool {
Expand All @@ -15,6 +15,8 @@ func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>)
return %1: !FHE.ebool
}

// -----

// CHECK-LABEL: func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 0, 0, 1]> : tensor<4xi64>
Expand All @@ -31,6 +33,8 @@ func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
return %1: !FHE.ebool
}

// -----

// CHECK-LABEL: func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 1]> : tensor<4xi64>
Expand All @@ -47,6 +51,8 @@ func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
return %1: !FHE.ebool
}

// -----

// CHECK-LABEL: func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[1, 1, 1, 0]> : tensor<4xi64>
Expand All @@ -63,6 +69,8 @@ func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
return %1: !FHE.ebool
}

// -----

// CHECK-LABEL: func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 0]> : tensor<4xi64>
Expand All @@ -79,6 +87,8 @@ func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
return %1: !FHE.ebool
}

// -----

// CHECK-LABEL: func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool
func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT1:.*]] = arith.constant dense<[0, 0, 1, 0]> : tensor<4xi64>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe --skip-program-info %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe --split-input-file --skip-program-info %s 2>&1| FileCheck %s


// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe<sk[1]<527,1>>, %arg1: i64) -> !TFHE.glwe<sk[1]<527,1>>
Expand All @@ -10,6 +10,8 @@ func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe<sk[1]<527,1>>, %arg1:
return %1: !TFHE.glwe<sk[1]<527,1>>
}

// -----

// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !TFHE.glwe<sk[1]<527,1>>) -> !TFHE.glwe<sk[1]<527,1>>
func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !TFHE.glwe<sk[1]<527,1>>) -> !TFHE.glwe<sk[1]<527,1>> {
// CHECK-NEXT: %[[V1:.*]] = "TFHE.zero"() : () -> !TFHE.glwe<sk[1]<527,1>>
Expand All @@ -20,6 +22,8 @@ func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !TFHE.glwe<sk[1]<527,1>>) -> !T
return %2: !TFHE.glwe<sk[1]<527,1>>
}

// -----

// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !TFHE.glwe<sk[1]<527,1>>) -> !TFHE.glwe<sk[1]<527,1>>
func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !TFHE.glwe<sk[1]<527,1>>) -> !TFHE.glwe<sk[1]<527,1>> {
// CHECK-NEXT: %[[V1:.*]] = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<sk[1]<527,1>>) -> !TFHE.glwe<sk[1]<527,1>>
Expand Down
Loading

0 comments on commit ac44865

Please sign in to comment.