From 6b73879eada42a30dde356bb7990ee580b915da5 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 26 Feb 2024 15:29:27 +0100 Subject: [PATCH 01/17] feat(compiler): Add support for tiling of FHELinalg.apply_lookup_table --- .../TensorOpsToLinalg.cpp | 3 ++ .../Dialect/FHELinalg/Transforms/Tiling.cpp | 8 ++++-- .../check_tests/Dialect/FHELinalg/tiling.mlir | 28 +++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 3dc83a8ca2..1fd921355d 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -697,6 +697,9 @@ struct FHELinalgApplyLookupTableToLinalgGeneric outs, maps, iteratorTypes, doc, call, bodyBuilder); + if (lutOp->hasAttr("tile-sizes")) + genericOp->setAttr("tile-sizes", lutOp->getAttr("tile-sizes")); + rewriter.replaceOp(lutOp, {genericOp.getResult(0)}); return ::mlir::success(); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp index 67cf341388..f2a327dc2b 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp @@ -185,8 +185,12 @@ class FHELinalgTilingMarkerPass mlir::ArrayAttr tileAttr = mlir::Builder(&this->getContext()).getI64ArrayAttr(tileSizes); - op->walk([&](mlir::concretelang::FHELinalg::MatMulEintIntOp matmulOp) { - matmulOp.getOperation()->setAttr("tile-sizes", tileAttr); + op->walk([&](mlir::Operation *op) { + if (llvm::isa( + op) || + llvm::isa(op)) { + op->setAttr("tile-sizes", tileAttr); + } }); } diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir index 8ca00b7244..069b514457 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir @@ -45,3 +45,31 @@ func.func @tiled_one_big_tile(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>) %0 = "FHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [0,0,4] } : (tensor<8x4x!FHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!FHE.eint<6>> return %0 : tensor<8x2x!FHE.eint<6>> } + +// ----- + +// CHECK: %[[V1:.*]] = scf.forall (%[[Varg2:.*]], %[[Varg3:.*]], %[[Varg4:.*]]) in (1, 1, 2) shared_outs(%[[Varg5:.*]] = %[[V0:.*]]) -> (tensor<2x3x4x!FHE.eint<2>>) { +// CHECK-NEXT: %[[V2:.*]] = affine.apply #map(%[[Varg2]]) +// CHECK-NEXT: %[[V3:.*]] = affine.apply #map1(%[[Varg3]]) +// CHECK-NEXT: %[[V4:.*]] = affine.apply #map(%[[Varg4]]) +// CHECK-NEXT: %[[V5:.*]] = affine.apply #map(%[[Varg2]]) +// CHECK-NEXT: %[[V6:.*]] = affine.apply #map1(%[[Varg3]]) +// CHECK-NEXT: %[[V7:.*]] = affine.apply #map(%[[Varg4]]) +// CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V2]], %[[V3]], %[[V4]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x4x!FHE.eint<2>> to tensor<2x3x2x!FHE.eint<2>> +// CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg5]]{{\[}}%[[V5]], %[[V6]], %[[V7]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : 
tensor<2x3x4x!FHE.eint<2>> to tensor<2x3x2x!FHE.eint<2>> +// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map2{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Vextracted_slice]] : tensor<2x3x2x!FHE.eint<2>>) outs(%[[Vextracted_slice_0]] : tensor<2x3x2x!FHE.eint<2>>) attrs = {"tile-sizes" = {{\[2, 3, 2\]}}} { +// CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<2>, %[[Vout:.*]]: !FHE.eint<2>): +// CHECK-NEXT: %[[V12:.*]] = "FHE.apply_lookup_table"(%[[Vin]], %[[Varg1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> +// CHECK-NEXT: linalg.yield %[[V12]] : !FHE.eint<2> +// CHECK-NEXT: } -> tensor<2x3x2x!FHE.eint<2>> +// CHECK-NEXT: %[[V9:.*]] = affine.apply #map(%[[Varg2]]) +// CHECK-NEXT: %[[V10:.*]] = affine.apply #map1(%[[Varg3]]) +// CHECK-NEXT: %[[V11:.*]] = affine.apply #map(%[[Varg4]]) +// CHECK-NEXT: scf.forall.in_parallel { +// CHECK-NEXT: tensor.parallel_insert_slice %[[V8]] into %[[Varg5]]{{\[}}%[[V9]], %[[V10]], %[[V11]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x2x!FHE.eint<2>> into tensor<2x3x4x!FHE.eint<2>> +// CHECK-NEXT: } +// CHECK-NEXT: } +func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> { + %1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1) { "tile-sizes" = [2,3,2] } : (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!FHE.eint<2>>) + return %1: tensor<2x3x4x!FHE.eint<2>> +} From 5c81882704c3dea448401801002844fb89049b60 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Thu, 29 Feb 2024 16:00:50 +0100 Subject: [PATCH 02/17] feat(compiler): Add new action dump-fhe-df-parallelized This adds a new option `dump-fhe-df-parallelized` to `concretecompiler` that dumps the IR after the generation of data-flow tasks. --- .../compiler/include/concretelang/Support/CompilerEngine.h | 4 ++++ .../compiler/lib/Support/CompilerEngine.cpp | 3 +++ compilers/concrete-compiler/compiler/src/main.cpp | 7 +++++++ 3 files changed, 14 insertions(+) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h index 68283185ec..e1fd0772ee 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h @@ -252,6 +252,10 @@ class CompilerEngine { /// from the Linalg dialect FHE_LINALG_GENERIC, + /// Read sources and lower all the FHELinalg operations to FHE + /// operations, dump after data-flow parallelization + FHE_DF_PARALLELIZED, + /// Read sources and lower all the FHELinalg operations to FHE operations /// and scf loops FHE_NO_LINALG, diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index 0c93dc0931..31a7cdd04d 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -459,6 +459,9 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, return StreamStringError("Dataflow parallelization failed"); } + if (target == Target::FHE_DF_PARALLELIZED) + return std::move(res); + if (mlir::concretelang::pipeline::lowerLinalgToLoops( mlirContext, module, enablePass, loopParallelize) .failed()) { diff --git a/compilers/concrete-compiler/compiler/src/main.cpp b/compilers/concrete-compiler/compiler/src/main.cpp index b5cfefdf00..a0abdd06f5 
100644 --- a/compilers/concrete-compiler/compiler/src/main.cpp +++ b/compilers/concrete-compiler/compiler/src/main.cpp @@ -51,6 +51,7 @@ namespace optimizer = mlir::concretelang::optimizer; enum Action { ROUND_TRIP, DUMP_FHE, + DUMP_FHE_DF_PARALLELIZED, DUMP_FHE_LINALG_GENERIC, DUMP_FHE_NO_LINALG, DUMP_TFHE, @@ -142,6 +143,9 @@ static llvm::cl::opt action( llvm::cl::values(clEnumValN(Action::DUMP_FHE_LINALG_GENERIC, "dump-fhe-linalg-generic", "Lower FHELinalg to Linalg and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_FHE_DF_PARALLELIZED, + "dump-fhe-df-parallelized", + "Dump result after data-flow parallelization")), llvm::cl::values(clEnumValN(Action::DUMP_FHE_NO_LINALG, "dump-fhe-no-linalg", "Lower FHELinalg to FHE and dump result")), @@ -580,6 +584,9 @@ mlir::LogicalResult processInputBuffer( case Action::DUMP_FHE_LINALG_GENERIC: target = mlir::concretelang::CompilerEngine::Target::FHE_LINALG_GENERIC; break; + case Action::DUMP_FHE_DF_PARALLELIZED: + target = mlir::concretelang::CompilerEngine::Target::FHE_DF_PARALLELIZED; + break; case Action::DUMP_FHE_NO_LINALG: target = mlir::concretelang::CompilerEngine::Target::FHE_NO_LINALG; break; From 6c1cef1fd3ef50fa2c747ff34b3abd450a94317b Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Thu, 14 Mar 2024 15:42:10 +0100 Subject: [PATCH 03/17] refactor(compiler): Factor normalization of loop IVs from Batching-specific code This introduces a new function `normalizeInductionVar()` to the static loop utility code in `concretelang/Analysis/StaticLoops.h` with code extracted for IV normalization from the batching code and changes the batching code to make use of the factored function. --- .../concretelang/Analysis/StaticLoops.h | 5 ++++ .../compiler/lib/Analysis/StaticLoops.cpp | 26 +++++++++++++++++++ .../compiler/lib/Transforms/Batching.cpp | 14 +++------- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h index 559e7136b9..0428ef19d2 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h @@ -7,6 +7,7 @@ #define CONCRETELANG_ANALYSIS_STATIC_LOOPS_H #include +#include namespace mlir { namespace concretelang { @@ -60,6 +61,10 @@ bool isConstantIndexValue(mlir::Value v); int64_t getConstantIndexValue(mlir::Value v); bool isConstantIndexValue(mlir::Value v, int64_t i); +mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder, + mlir::Value iv, mlir::OpFoldResult lb, + mlir::OpFoldResult step); + } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp index 4a22a588cd..8383d9bb5a 100644 --- a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp +++ b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp @@ -5,6 +5,8 @@ #include #include +#include +#include #include #include @@ -446,5 +448,29 @@ bool isConstantIndexValue(mlir::Value v, int64_t i) { return isConstantIndexValue(v) && getConstantIndexValue(v) == i; } +// Returns a `Value` corresponding to `iv`, normalized to the lower +// bound `lb` and step `step` of a loop (i.e., (iv - lb) / step). 
+mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder, + mlir::Value iv, mlir::OpFoldResult lb, + mlir::OpFoldResult step) { + std::optional lbInt = mlir::getConstantIntValue(lb); + std::optional stepInt = mlir::getConstantIntValue(step); + + mlir::Value idxShifted = lbInt.has_value() && *lbInt == 0 + ? iv + : builder.create( + iv, mlir::getValueOrCreateConstantIndexOp( + builder, builder.getLoc(), lb)); + + mlir::Value normalizedIV = + stepInt.has_value() && *stepInt == 1 + ? idxShifted + : builder.create( + idxShifted, mlir::getValueOrCreateConstantIndexOp( + builder, builder.getLoc(), step)); + + return normalizedIV; +} + } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp b/compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp index 01bec55e56..2d30399492 100644 --- a/compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp +++ b/compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp @@ -516,18 +516,10 @@ buildNormalizedIndexes(mlir::PatternRewriter &rewriter, rewriter.setInsertionPointToStart(innermost.getBody()); for (mlir::scf::ForOp forOp : nest) { + mlir::ImplicitLocOpBuilder ilob(forOp.getLoc(), rewriter); - mlir::Value idxShifted = - isConstantIndexValue(forOp.getLowerBound(), 0) - ? forOp.getInductionVar() - : rewriter.create(innermost.getLoc(), - forOp.getInductionVar(), - forOp.getLowerBound()); - mlir::Value idx = - isConstantIndexValue(forOp.getStep(), 1) - ? idxShifted - : rewriter.create( - innermost.getLoc(), idxShifted, forOp.getStep()); + mlir::Value idx = normalizeInductionVar( + ilob, forOp.getInductionVar(), forOp.getLowerBound(), forOp.getStep()); res.push_back(idx); } From 93bb849233d98d6397139e306cda7c61bcc2ef5a Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 06:12:12 +0200 Subject: [PATCH 04/17] refactor(compiler): Use reinstantiating conversion patterns for RT operations --- .../Conversion/Utils/RTOpConverter.h | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/RTOpConverter.h b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/RTOpConverter.h index 1d25b1d384..559b28bab6 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/RTOpConverter.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/RTOpConverter.h @@ -6,8 +6,8 @@ #ifndef CONCRETELANG_CONVERSION_RTOPCONVERTER_H_ #define CONCRETELANG_CONVERSION_RTOPCONVERTER_H_ -#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" #include "concretelang/Conversion/Utils/Legality.h" +#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -20,25 +20,25 @@ populateWithRTTypeConverterPatterns(mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, mlir::TypeConverter &converter) { patterns.add< - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DataflowTaskOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DataflowYieldOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< 
mlir::concretelang::RT::MakeReadyFutureOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::AwaitFutureOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::CreateAsyncTaskOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::concretelang::RT::CreateAsyncTaskOp, true>, + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::WorkFunctionReturnOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>( patterns.getContext(), converter); From 74efe8be0a35d48f45a892030ec99fc95411e504 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 06:26:31 +0200 Subject: [PATCH 05/17] feat(compiler): Declare RT futures usable as element types for memrefs --- .../compiler/include/concretelang/Dialect/RT/IR/RTTypes.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/IR/RTTypes.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/IR/RTTypes.td index bc0ebaea93..687a318f56 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/IR/RTTypes.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/IR/RTTypes.td @@ -7,7 +7,7 @@ include "mlir/IR/BuiltinTypes.td" class RT_Type traits = []> : TypeDef { } -def RT_Future : RT_Type<"Future"> { +def RT_Future : RT_Type<"Future", [MemRefElementTypeInterface]> { let mnemonic = "future"; let summary = "Future with a parameterized element type"; From a855e2bef6de1ca6034fa03932695673771a3e26 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 06:29:22 +0200 Subject: [PATCH 06/17] refactor(compiler): Make type conversion in scalar FHE to TFHE conversion recursive --- .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index a816faeaa8..2a0e30e901 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -84,29 +84,6 @@ TFHE::GLWECipherTextType convertEncrypted(mlir::MLIRContext *context, return TFHE::GLWECipherTextType::get(context, TFHE::GLWESecretKey()); } -/// Converts `Tensor` into a -/// `Tensor` if the element type is appropriate. -/// Otherwise return the input type. 
-mlir::Type -maybeConvertEncryptedTensor(mlir::MLIRContext *context, - mlir::RankedTensorType maybeEncryptedTensor) { - if (!maybeEncryptedTensor.getElementType().isa()) { - return (mlir::Type)(maybeEncryptedTensor); - } - auto currentShape = maybeEncryptedTensor.getShape(); - return mlir::RankedTensorType::get( - currentShape, - TFHE::GLWECipherTextType::get(context, TFHE::GLWESecretKey())); -} - -/// Converts any encrypted type to `TFHE::GlweCiphetext` if the -/// input type is appropriate. Otherwise return the input type. -mlir::Type maybeConvertEncrypted(mlir::MLIRContext *context, mlir::Type t) { - if (auto eint = t.dyn_cast()) - return convertEncrypted(context, eint); - return t; -} - /// The type converter used to convert `FHE` to `TFHE` types using the scalar /// strategy. class TypeConverter : public mlir::TypeConverter { @@ -114,6 +91,21 @@ class TypeConverter : public mlir::TypeConverter { public: TypeConverter() { addConversion([](mlir::Type type) { return type; }); + addConversion([&](mlir::FunctionType type) { + llvm::SmallVector inputTypes; + llvm::SmallVector resultTypes; + + for (mlir::Type inputType : type.getInputs()) + inputTypes.push_back(this->convertType(inputType)); + + for (mlir::Type resultType : type.getResults()) + resultTypes.push_back(this->convertType(resultType)); + + mlir::Type res = + mlir::FunctionType::get(type.getContext(), inputTypes, resultTypes); + + return res; + }); addConversion([](FHE::FheIntegerInterface type) { return convertEncrypted(type.getContext(), type); }); @@ -121,8 +113,9 @@ class TypeConverter : public mlir::TypeConverter { return TFHE::GLWECipherTextType::get(type.getContext(), TFHE::GLWESecretKey()); }); - addConversion([](mlir::RankedTensorType type) { - return maybeConvertEncryptedTensor(type.getContext(), type); + addConversion([&](mlir::RankedTensorType type) { + return mlir::RankedTensorType::get( + type.getShape(), this->convertType(type.getElementType())); }); addConversion([&](mlir::concretelang::RT::FutureType type) { return mlir::concretelang::RT::FutureType::get( From 9216c617e42cc9da78084a96d31c775e9708b942 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 06:31:43 +0200 Subject: [PATCH 07/17] refactor(compiler): Make type conversion in TFHE global parametrization recursive --- .../TFHEGlobalParametrization.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 074075b8f3..00f27ce9a7 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -53,13 +53,8 @@ class TFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter { } }); addConversion([&](mlir::RankedTensorType type) { - auto glwe = type.getElementType().dyn_cast_or_null(); - if (glwe == nullptr || !glwe.getKey().isNone()) { - return (mlir::Type)(type); - } - mlir::Type r = mlir::RankedTensorType::get(type.getShape(), - this->glweInterPBSType(glwe)); - return r; + return mlir::RankedTensorType::get( + type.getShape(), this->convertType(type.getElementType())); }); addConversion([&](mlir::concretelang::RT::FutureType type) { return mlir::concretelang::RT::FutureType::get( From 68d0014218145ae2c13f50b6a7621a3936f3f562 Mon 
Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 06:34:48 +0200 Subject: [PATCH 08/17] feat(compiler): Support non-ciphertext types in TFHE to Concrete conversion patterns Some of the TFHE to Concrete conversion patterns implicitly assume that operands are ciphertexts and thus that the converted types have a higher number of dimensions than the original types. However, for non-ciphertext types, the number of dimensions before and after the conversion must be the same. This commit adds a check to the respective conversion patterns triggering a simple type conversion that preserves the number of dimensions for non-ciphertext types. --- .../TFHEToConcrete/TFHEToConcrete.cpp | 69 ++++++++++++++----- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index 94677739c9..7e57cc6c87 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -59,7 +59,9 @@ class TFHEToConcreteTypeConverter : public mlir::TypeConverter { addConversion([&](mlir::RankedTensorType type) { auto glwe = type.getElementType().dyn_cast_or_null(); if (glwe == nullptr) { - return (mlir::Type)(type); + return mlir::RankedTensorType::get( + type.getShape(), this->convertType(type.getElementType())) + .cast(); } mlir::SmallVector newShape; newShape.reserve(type.getShape().size() + 1); @@ -458,9 +460,21 @@ struct ExtractOpPattern return mlir::failure(); } - auto newResultType = this->getTypeConverter() - ->convertType(extractOp.getType()) - .cast(); + auto newResultType = + this->getTypeConverter()->convertType(extractOp.getType()); + + // If the extraction is not on a tensor of ciphertexts, just + // convert the type and keep the rest as-is. + if (!extractOp.getType().isa()) { + rewriter.replaceOpWithNewOp( + extractOp, newResultType, adaptor.getTensor(), adaptor.getIndices()); + + return mlir::success(); + } + + mlir::RankedTensorType newResultTensorType = + newResultType.cast(); + auto tensorRank = adaptor.getTensor().getType().cast().getRank(); @@ -473,14 +487,15 @@ struct ExtractOpPattern // [1..., nbBlock, lweDimension+1] mlir::SmallVector staticSizes(tensorRank, 1); staticSizes[staticSizes.size() - 1] = - newResultType.getDimSize(newResultType.getRank() - 1); + newResultTensorType.getDimSize(newResultTensorType.getRank() - 1); // [1...] 
for static_strides mlir::SmallVector staticStrides(tensorRank, 1); rewriter.replaceOpWithNewOp( - extractOp, newResultType, adaptor.getTensor(), adaptor.getIndices(), - mlir::SmallVector{}, mlir::SmallVector{}, + extractOp, newResultTensorType, adaptor.getTensor(), + adaptor.getIndices(), mlir::SmallVector{}, + mlir::SmallVector{}, rewriter.getDenseI64ArrayAttr(staticOffsets), rewriter.getDenseI64ArrayAttr(staticSizes), rewriter.getDenseI64ArrayAttr(staticStrides)); @@ -503,24 +518,34 @@ struct InsertSliceOpPattern : public mlir::OpConversionPattern { ::mlir::LogicalResult matchAndRewrite(OpTy insertSliceOp, typename OpTy::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { + bool needsExtraDimension = insertSliceOp.getDest() + .getType() + .getElementType() + .template isa(); + mlir::RankedTensorType newDestTy = ((mlir::Type)adaptor.getDest().getType()) .cast(); - // add 0 to offsets mlir::SmallVector offsets = getMixedValues( adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter); - offsets.push_back(rewriter.getI64IntegerAttr(0)); - // add lweDimension+1 to sizes mlir::SmallVector sizes = getMixedValues(adaptor.getStaticSizes(), adaptor.getSizes(), rewriter); - sizes.push_back(rewriter.getI64IntegerAttr( - newDestTy.getDimSize(newDestTy.getRank() - 1))); - // add 1 to the strides mlir::SmallVector strides = getMixedValues( adaptor.getStaticStrides(), adaptor.getStrides(), rewriter); - strides.push_back(rewriter.getI64IntegerAttr(1)); + + if (needsExtraDimension) { + // add 0 to offsets + offsets.push_back(rewriter.getI64IntegerAttr(0)); + + // add lweDimension+1 to sizes + sizes.push_back(rewriter.getI64IntegerAttr( + newDestTy.getDimSize(newDestTy.getRank() - 1))); + + // add 1 to the strides + strides.push_back(rewriter.getI64IntegerAttr(1)); + } // replace insert slice-like operation with the new one rewriter.replaceOpWithNewOp(insertSliceOp, adaptor.getSource(), @@ -528,7 +553,7 @@ struct InsertSliceOpPattern : public mlir::OpConversionPattern { strides); return ::mlir::success(); - }; + } }; /// Pattern that rewrites the Insert operation, taking into account the @@ -599,13 +624,23 @@ struct FromElementsOpPattern matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp, mlir::tensor::FromElementsOp::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { + auto converter = this->getTypeConverter(); // is not a tensor of GLWEs that need to be extended with the LWE dimension - if (this->getTypeConverter()->isLegal(fromElementsOp.getType())) { + if (converter->isLegal(fromElementsOp.getType())) { return mlir::failure(); } - auto converter = this->getTypeConverter(); + // If the element type is not directly a cipher text type, the + // shape of the output does not change. In this case, the op type + // can be preserved and only type conversion is necessary. 
+ if (!fromElementsOp.getType().getElementType().isa()) { + rewriter.replaceOpWithNewOp( + fromElementsOp, converter->convertType(fromElementsOp.getType()), + adaptor.getOperands()); + + return mlir::success(); + } auto resultTy = fromElementsOp.getResult().getType(); if (converter->isLegal(resultTy)) { From f668e82f2064d7699a7093fc0a94b4c446f323de Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 06:41:59 +0200 Subject: [PATCH 09/17] refactor(compiler): Make type conversion in RT task bufferization recursive --- .../RT/Analysis/BufferizeDataflowTaskOps.cpp | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp index bcea750ab6..8fd35c8292 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp @@ -35,26 +35,54 @@ namespace concretelang { namespace { -class BufferizeRTTypesConverter - : public mlir::bufferization::BufferizeTypeConverter { +class BufferizeRTTypesConverter : public mlir::TypeConverter { +protected: + bufferization::BufferizeTypeConverter btc; + public: BufferizeRTTypesConverter() { + addConversion([&](mlir::Type type) { return btc.convertType(type); }); + + addConversion([&](mlir::RankedTensorType type) { + return mlir::MemRefType::get(type.getShape(), + this->convertType(type.getElementType())); + }); + + addConversion([&](mlir::UnrankedTensorType type) { + return mlir::UnrankedMemRefType::get( + this->convertType(type.getElementType()), 0); + }); + + addConversion([&](mlir::MemRefType type) { + return mlir::MemRefType::get(type.getShape(), + this->convertType(type.getElementType()), + type.getLayout(), type.getMemorySpace()); + }); + + addConversion([&](mlir::UnrankedMemRefType type) { + return mlir::UnrankedMemRefType::get( + this->convertType(type.getElementType()), type.getMemorySpace()); + }); + addConversion([&](mlir::concretelang::RT::FutureType type) { return mlir::concretelang::RT::FutureType::get( - this->convertType(type.dyn_cast() - .getElementType())); + this->convertType(type.getElementType())); }); + addConversion([&](mlir::concretelang::RT::PointerType type) { return mlir::concretelang::RT::PointerType::get( - this->convertType(type.dyn_cast() - .getElementType())); + this->convertType(type.getElementType())); }); + addConversion([&](mlir::FunctionType type) { SignatureConversion result(type.getNumInputs()); mlir::SmallVector newResults; + if (failed(this->convertSignatureArgs(type.getInputs(), result)) || - failed(this->convertTypes(type.getResults(), newResults))) + failed(this->convertTypes(type.getResults(), newResults))) { return type; + } + return mlir::FunctionType::get(type.getContext(), result.getConvertedTypes(), newResults); }); From 0c7e3a3518471b786a171c3257292999cd308dd1 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 11:00:33 +0200 Subject: [PATCH 10/17] feat(compiler): Add support for nested Memrefs in memory usage estimator --- .../Dialect/Concrete/Analysis/MemoryUsage.cpp | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp index 4e5d0bf137..b2bf4a814f 100644 
--- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -17,15 +18,28 @@ using namespace mlir::concretelang; using namespace mlir; namespace { +outcome::checked +getBufferSize(mlir::MemRefType bufferType); -int64_t getElementTypeSize(mlir::Type elementType) { +outcome::checked +getElementTypeSize(mlir::Type elementType) { if (auto integerType = mlir::dyn_cast(elementType)) { auto width = integerType.getWidth(); return std::ceil((double)width / 8); - } - if (mlir::dyn_cast(elementType)) { + } else if (mlir::dyn_cast(elementType)) { return 8; + } else if (auto memrefType = mlir::dyn_cast(elementType)) { + return getBufferSize(memrefType); + } else if (auto futureType = mlir::dyn_cast(elementType)) { + auto recSize = getElementTypeSize(futureType.getElementType()); + + if (!recSize) + return recSize.error(); + + // FIXME: Hardcoded size of 8 bytes for a pointer + return 8 * recSize.value(); } + return -1; } @@ -33,7 +47,13 @@ outcome::checked getBufferSize(mlir::MemRefType bufferType) { auto shape = bufferType.getShape(); auto elementType = bufferType.getElementType(); - auto elementSize = getElementTypeSize(elementType); + auto maybeElementSize = getElementTypeSize(elementType); + + if (!maybeElementSize) + return maybeElementSize.error(); + + auto elementSize = maybeElementSize.value(); + if (elementSize == -1) return StringError( "allocation of buffer with a non-supported element-type"); From 48d919bd258cddb276885eb8a71b1e00588944cc Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 11:02:41 +0200 Subject: [PATCH 11/17] feat(compiler): Add support for tensor.{from_elements,dim} operations in TFHE passes --- .../Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp | 9 +++++++++ .../TFHEGlobalParametrization.cpp | 11 +++++++++++ .../TFHEKeyNormalization/TFHEKeyNormalization.cpp | 6 ++++++ .../Conversion/TFHEToConcrete/TFHEToConcrete.cpp | 14 ++++++++------ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 2a0e30e901..564daf4451 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -853,6 +853,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { mlir::scf::ForallOp>, mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::tensor::EmptyOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::FromElementsOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::DimOp>, mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::tensor::ParallelInsertSliceOp, true>>(&getContext(), converter); @@ -875,6 +879,8 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { target, converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::tensor::ParallelInsertSliceOp>(target, converter); @@ -910,6 +916,9 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + 
mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 00f27ce9a7..4f9942f3f1 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -348,6 +348,17 @@ void TFHEGlobalParametrizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + &getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + + patterns.add>(&getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + patterns.add>( &getContext(), converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp index 8b739718d8..f8288d7c53 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp @@ -371,6 +371,12 @@ void TFHEKeyNormalizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp( target, typeConverter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + &getContext(), typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, typeConverter); + patterns.add>( &getContext(), typeConverter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index 7e57cc6c87..4b35e8ef49 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -926,11 +926,11 @@ void TFHEToConcretePass::runOnOperation() { mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp, mlir::tensor::InsertSliceOp, mlir::tensor::ParallelInsertSliceOp, mlir::tensor::ExpandShapeOp, mlir::tensor::CollapseShapeOp, - mlir::tensor::EmptyOp, mlir::bufferization::AllocTensorOp>( - [&](mlir::Operation *op) { - return converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getOperandTypes()); - }); + mlir::tensor::EmptyOp, mlir::tensor::FromElementsOp, mlir::tensor::DimOp, + mlir::bufferization::AllocTensorOp>([&](mlir::Operation *op) { + return converter.isLegal(op->getResultTypes()) && + converter.isLegal(op->getOperandTypes()); + }); // rewrite scf for loops if working on illegal types patterns.add, mlir::concretelang::TypeConvertingReinstantiationPattern< - mlir::tensor::EmptyOp, true>>(&getContext(), converter); + mlir::tensor::EmptyOp, true>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::DimOp>>(&getContext(), converter); mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); From 3ad3dcb08ff770e593abdc692f8aa6327efd4cbc Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 
2024 12:07:03 +0200 Subject: [PATCH 12/17] refactor(compiler): Use signature conversion for conversion of ops with nested blocks The current scheme used by reinstantiating conversion patterns in `lib/Conversion/Utils/Dialects` for operations with blocks is to create a new operation with empty blocks, to move the operations from the old blocks and then to replace any references to block arguments. However, such in-place updates of the types of block arguments leave conversion patterns for operations nested in the blocks without the ability to determine the original types of values from before the update. This change uses proper signature conversion for block arguments, such that the original types of block arguments with converted types is preserved, while the new types are made available through the dialect conversion infrastructure via the respective adaptors. --- .../concretelang/Conversion/Utils/Utils.h | 6 ++ .../lib/Conversion/Utils/Dialects/SCF.cpp | 81 +++---------------- .../lib/Conversion/Utils/Dialects/Tensor.cpp | 25 +----- .../compiler/lib/Conversion/Utils/Utils.cpp | 28 +++++++ 4 files changed, 48 insertions(+), 92 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h index dc95b47767..b7b7d4c9b0 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h @@ -22,6 +22,12 @@ mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter, mlir::Location loc, mlir::ArrayAttr arrAttr); +mlir::Operation *convertOpWithBlocks(mlir::Operation *op, + mlir::ValueRange newOperands, + mlir::TypeRange newResultTypes, + mlir::TypeConverter &typeConverter, + mlir::ConversionPatternRewriter &rewriter); + } // namespace concretelang } // namespace mlir #endif diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/SCF.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/SCF.cpp index 0f002782d3..033afd685d 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/SCF.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/SCF.cpp @@ -4,7 +4,7 @@ // for license information. 
#include "concretelang/Conversion/Utils/Dialects/SCF.h" -#include "mlir/Transforms/RegionUtils.h" +#include "concretelang/Conversion/Utils/Utils.h" namespace mlir { namespace concretelang { @@ -13,29 +13,16 @@ mlir::LogicalResult TypeConvertingReinstantiationPattern::matchAndRewrite( scf::ForOp oldOp, mlir::OpConversionPattern::OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - // Create new for loop with empty body, but converted iter args - scf::ForOp newForOp = rewriter.create( - oldOp.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), - adaptor.getStep(), adaptor.getInitArgs(), - [&](OpBuilder &builder, Location loc, Value iv, ValueRange args) {}); - - newForOp->setAttrs(adaptor.getAttributes()); - - // Move operations from old for op to new one - auto &newOperations = newForOp.getBody()->getOperations(); - mlir::Block *oldBody = oldOp.getBody(); - - newOperations.splice(newOperations.begin(), oldBody->getOperations(), - oldBody->begin(), oldBody->end()); + mlir::TypeConverter &typeConverter = *getTypeConverter(); + llvm::SmallVector convertedResultTypes; - // Remap iter args and IV - for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(), - newForOp.getBody()->getArguments())) { - replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair), - newForOp.getRegion()); + if (typeConverter.convertTypes(oldOp.getResultTypes(), convertedResultTypes) + .failed()) { + return mlir::failure(); } - rewriter.replaceOp(oldOp, newForOp.getResults()); + convertOpWithBlocks(oldOp, adaptor.getOperands(), convertedResultTypes, + typeConverter, rewriter); return mlir::success(); } @@ -49,56 +36,10 @@ TypeConvertingReinstantiationPattern::matchAndRewrite( scf::ForallOp oldOp, mlir::OpConversionPattern::OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - // Create new forall operation with empty body, but converted iter - // args - llvm::SmallVector lbs = getMixedValues( - adaptor.getStaticLowerBound(), adaptor.getDynamicLowerBound(), rewriter); - llvm::SmallVector ubs = getMixedValues( - adaptor.getStaticUpperBound(), adaptor.getDynamicUpperBound(), rewriter); - llvm::SmallVector step = getMixedValues( - adaptor.getStaticStep(), adaptor.getDynamicStep(), rewriter); - - rewriter.setInsertionPoint(oldOp); - - scf::ForallOp newForallOp = rewriter.create( - oldOp.getLoc(), lbs, ubs, step, adaptor.getOutputs(), - adaptor.getMapping()); - - newForallOp->setAttrs(adaptor.getAttributes()); - - // Move operations from old for op to new one - auto &newOperations = newForallOp.getBody()->getOperations(); - mlir::Block *oldBody = oldOp.getBody(); - - newOperations.splice(newOperations.begin(), oldBody->getOperations(), - oldBody->begin(), std::prev(oldBody->end())); - - // Move operations from `scf.forall.in_parallel` terminator of the - // old op to the terminator of the new op - - mlir::scf::InParallelOp oldTerminator = - llvm::dyn_cast(*std::prev(oldBody->end())); - - assert(oldTerminator && "Last operation of `scf.forall` op expected be a " - "`scf.forall.in_parallel` op"); - - mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator(); - - mlir::Block::OpListType &oldTerminatorOps = - oldTerminator.getRegion().getBlocks().begin()->getOperations(); - mlir::Block::OpListType &newTerminatorOps = - newTerminator.getRegion().getBlocks().begin()->getOperations(); - - newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps, - oldTerminatorOps.begin(), oldTerminatorOps.end()); - - // Remap iter args and IV - for (auto argsPair : 
llvm::zip(oldOp.getBody()->getArguments(), - newForallOp.getBody()->getArguments())) { - std::get<0>(argsPair).replaceAllUsesWith(std::get<1>(argsPair)); - } - rewriter.replaceOp(oldOp, newForallOp.getResults()); + convertOpWithBlocks(oldOp, adaptor.getOperands(), + adaptor.getOutputs().getTypes(), *getTypeConverter(), + rewriter); return mlir::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp index 52edb1b588..7e0670acbc 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp @@ -4,8 +4,7 @@ // for license information. #include "concretelang/Conversion/Utils/Dialects/Tensor.h" -#include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/STLExtras.h" +#include "concretelang/Conversion/Utils/Utils.h" namespace mlir { namespace concretelang { @@ -79,26 +78,8 @@ TypeConvertingReinstantiationPattern::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { mlir::SmallVector resultTypes = convertResultTypes(oldOp); - rewriter.setInsertionPointAfter(oldOp); - tensor::GenerateOp newGenerateOp = rewriter.create( - oldOp.getLoc(), resultTypes, adaptor.getOperands(), oldOp->getAttrs()); - - mlir::Block &oldBlock = oldOp.getBody().getBlocks().front(); - mlir::Block &newBlock = newGenerateOp.getBody().getBlocks().front(); - auto begin = oldBlock.begin(); - auto nOps = oldBlock.getOperations().size(); - - newBlock.getOperations().splice(newBlock.getOperations().begin(), - oldBlock.getOperations(), begin, - std::next(begin, nOps - 1)); - - for (auto argsPair : llvm::zip(oldOp.getRegion().getArguments(), - newGenerateOp.getRegion().getArguments())) { - replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair), - newGenerateOp.getRegion()); - } - - rewriter.replaceOp(oldOp, newGenerateOp.getResult()); + convertOpWithBlocks(oldOp, adaptor.getOperands(), resultTypes, + *getTypeConverter(), rewriter); return mlir::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp index 719a8e2b91..17d4c5904e 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp @@ -6,6 +6,7 @@ #include "concretelang/Conversion/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" +#include "mlir/Transforms/RegionUtils.h" namespace mlir { namespace concretelang { @@ -56,5 +57,32 @@ mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter, return mlir::concretelang::getCastedMemRef(rewriter, globalRef); } +// Converts an operation `op` with nested blocks using a type +// converter and a conversion pattern rewriter, such that the newly +// created operation uses the operands specified in `newOperands` and +// returns a value of the types `newResultTypes`. 
+mlir::Operation * +convertOpWithBlocks(mlir::Operation *op, mlir::ValueRange newOperands, + mlir::TypeRange newResultTypes, + mlir::TypeConverter &typeConverter, + mlir::ConversionPatternRewriter &rewriter) { + mlir::OperationState state(op->getLoc(), op->getName().getStringRef(), + newOperands, newResultTypes, op->getAttrs(), + op->getSuccessors()); + + for (Region ®ion : op->getRegions()) { + Region *newRegion = state.addRegion(); + rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); + TypeConverter::SignatureConversion result(newRegion->getNumArguments()); + (void)typeConverter.convertSignatureArgs(newRegion->getArgumentTypes(), + result); + rewriter.applySignatureConversion(newRegion, result); + } + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + + return newOp; +} } // namespace concretelang } // namespace mlir From 999c9a9addd9776a514a8dc28605151e1e6b2ad6 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 14:33:21 +0200 Subject: [PATCH 13/17] feat(compiler): Add support for dynamically-sized memrefs in lowering patterns for RT tasks --- .../RT/Analysis/LowerDataflowTasksToRT.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index b76b5c1d29..574cd38e4e 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -568,11 +568,22 @@ struct FinalizeTaskCreationPass .setLayout(AffineMapAttr::get( builder.getMultiDimIdentityMap(rank))); } + + llvm::SmallVector dynamicDimSizes; + + for (auto dimSizeIt : llvm::enumerate(mrType.getShape())) { + if (mlir::ShapedType::isDynamic(dimSizeIt.value())) { + mlir::memref::DimOp dimOp = builder.create( + val.getLoc(), val, dimSizeIt.index()); + dynamicDimSizes.push_back(dimOp.getResult()); + } + } + // We need to make a copy of this MemRef to allow deallocation // based on refcounting - Value newval = - builder.create(val.getLoc(), mrType) - .getResult(); + mlir::memref::AllocOp newval = builder.create( + val.getLoc(), mrType, dynamicDimSizes); + builder.create(val.getLoc(), val, newval); clone = builder.create(op.getLoc(), builder.getI64IntegerAttr(1)); From fd513f1e6e5783cc72e0b1fef29186759b07801b Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Apr 2024 14:35:41 +0200 Subject: [PATCH 14/17] feat(compiler): Add support for various memref operations for RT task bufferization This adds support for `memref.alloc`, `memref.load`, `memref.store`, `memref.copy` and `memref.subview` to the RT task bufferization pass. 
--- .../RT/Analysis/BufferizeDataflowTaskOps.cpp | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp index 8fd35c8292..d8f0ee1676 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp @@ -123,6 +123,26 @@ struct BufferizeDataflowTaskOpsPass mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, typeConverter); + patterns.add, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::memref::LoadOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::memref::StoreOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::memref::CopyOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::memref::SubViewOp, true>>(&getContext(), + typeConverter); + + target.addDynamicallyLegalOp( + [&](mlir::Operation *op) { + return typeConverter.isLegal(op->getResultTypes()) && + typeConverter.isLegal(op->getOperandTypes()); + }); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } From d620fa9a44cfdba5926dfbfb56c478e5e9202bff Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Thu, 14 Mar 2024 06:30:22 +0100 Subject: [PATCH 15/17] feat(compiler): Add pass hoisting RT.await_future out of scf.forall loops The new pass hoists `RT.await_future` operations whose results are yielded by scf.forall operations out of the loops in order to avoid over-synchronization of data-flow tasks. E.g., the following IR: ``` scf.forall (%arg) in (16) shared_outs(%o1 = %sometensor, %o2 = %someothertensor) -> (tensor<...>, tensor<...>) { ... %rph = "RT.build_return_ptr_placeholder"() : () -> !RT.rtptr>> "RT.create_async_task"(..., %rph, ...) { ... } : ... %future = "RT.deref_return_ptr_placeholder"(%rph) : (!RT.rtptr>) -> !RT.future> %res = "RT.await_future"(%future) : (!RT.future>) -> tensor<...> ... scf.forall.in_parallel { ... tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] : tensor<...> into tensor<...> ... } } ``` is transformed into: ``` %tensoroffutures = tensor.empty() : tensor<16x!RT.future>> scf.forall (%arg) in (16) shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor) -> (tensor<...>, tensor<...>) { ... %rph = "RT.build_return_ptr_placeholder"() : () -> !RT.rtptr>> "RT.create_async_task"(..., %rph, ...) { ... } : ... %future = "RT.deref_return_ptr_placeholder"(%rph) : (!RT.rtptr>) -> !RT.future> %wrappedfuture = tensor.from_elements %future : tensor<1x!RT.future>> ... scf.forall.in_parallel { ... tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] : tensor<1xRT.future>> into tensor<16x!RT.future>> ... } } scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) { %future = tensor.extract %tensoroffutures[%arg] : tensor<4x!RT.future>> %res = "RT.await_future"(%future) : (!RT.future>) -> tensor<...> scf.forall.in_parallel { tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] 
: tensor<...> into tensor<...> } } ``` --- .../concretelang/Analysis/StaticLoops.h | 6 + .../concretelang/Dialect/RT/CMakeLists.txt | 1 + .../Dialect/RT/Transforms/CMakeLists.txt | 3 + .../Dialect/RT/Transforms/Passes.h | 25 ++ .../Dialect/RT/Transforms/Passes.td | 82 ++++++ .../compiler/lib/Analysis/StaticLoops.cpp | 14 + .../TFHEGlobalParametrization.cpp | 8 +- .../TFHEKeyNormalization.cpp | 5 + .../lib/Dialect/RT/Transforms/CMakeLists.txt | 2 + .../RT/Transforms/HoistAwaitFuturePass.cpp | 259 ++++++++++++++++++ .../compiler/lib/Support/Pipeline.cpp | 3 + 11 files changed, 403 insertions(+), 5 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h index 0428ef19d2..829222b0b8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h @@ -65,6 +65,12 @@ mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder, mlir::Value iv, mlir::OpFoldResult lb, mlir::OpFoldResult step); +llvm::SmallVector +normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder, + mlir::ValueRange ivs, + llvm::ArrayRef lbs, + llvm::ArrayRef steps); + } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt index 4f74948933..306b439685 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Analysis) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..0661dc36df --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name RT) +add_public_tablegen_target(RTTransformsIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h new file mode 100644 index 0000000000..371a7544c2 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h @@ -0,0 +1,25 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. 
+ +#ifndef CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H +#define CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" + +#include "concretelang/Dialect/RT/IR/RTDialect.h" + +#define GEN_PASS_CLASSES +#include "concretelang/Dialect/RT/Transforms/Passes.h.inc" + +namespace mlir { +namespace concretelang { +std::unique_ptr> createHoistAwaitFuturePass(); +} // namespace concretelang +} // namespace mlir + +#endif // CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td new file mode 100644 index 0000000000..f638400d58 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td @@ -0,0 +1,82 @@ +#ifndef MLIR_DIALECT_RT_TRANSFORMS_PASSES +#define MLIR_DIALECT_RT_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def HoistAwaitFuturePass : Pass<"hoist-await-future", "mlir::func::FuncOp"> { + let summary = "Hoists `RT.await_future` operations whose results are yielded " + "by `scf.forall` operations out of the loops"; + let description = [{ + Hoists `RT.await_future` operations whose results are yielded by + scf.forall operations out of the loops in order to avoid + over-synchronization of data-flow tasks. + + E.g., the following IR: + + ``` + scf.forall (%arg) in (16) + shared_outs(%o1 = %sometensor, %o2 = %someothertensor) + -> (tensor<...>, tensor<...>) + { + ... + %rph = "RT.build_return_ptr_placeholder"() : + () -> !RT.rtptr>> + "RT.create_async_task"(..., %rph, ...) { ... } : ... + %future = "RT.deref_return_ptr_placeholder"(%rph) : + (!RT.rtptr>) -> !RT.future> + %res = "RT.await_future"(%future) : (!RT.future>) -> tensor<...> + ... + scf.forall.in_parallel { + ... + tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] : + tensor<...> into tensor<...> + ... + } + } + ``` + + is transformed into: + + ``` + %tensoroffutures = tensor.empty() : tensor<16x!RT.future>> + + scf.forall (%arg) in (16) + shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor) + -> (tensor<...>, tensor<...>) + { + ... + %rph = "RT.build_return_ptr_placeholder"() : + () -> !RT.rtptr>> + "RT.create_async_task"(..., %rph, ...) { ... } : ... + %future = "RT.deref_return_ptr_placeholder"(%rph) : + (!RT.rtptr>) -> !RT.future> + %wrappedfuture = tensor.from_elements %future : + tensor<1x!RT.future>> + ... + scf.forall.in_parallel { + ... + tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] : + tensor<1xRT.future>> into tensor<16x!RT.future>> + ... + } + } + + scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) { + %future = tensor.extract %tensoroffutures[%arg] : + tensor<4x!RT.future>> + %res = "RT.await_future"(%future) : (!RT.future>) -> tensor<...> + scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] 
+          tensor<...> into tensor<...>
+      }
+    }
+    ```
+  }];
+  let constructor = "mlir::concretelang::createHoistAwaitFuturePass()";
+  let dependentDialects = [
+    "mlir::func::FuncDialect", "mlir::scf::SCFDialect",
+    "mlir::tensor::TensorDialect", "mlir::concretelang::RT::RTDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_RT_TRANSFORMS_PASSES
diff --git a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp
index 8383d9bb5a..2515dbf081 100644
--- a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp
@@ -472,5 +472,19 @@ mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder,
   return normalizedIV;
 }
 
+llvm::SmallVector<mlir::Value>
+normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder,
+                       mlir::ValueRange ivs,
+                       llvm::ArrayRef<mlir::OpFoldResult> lbs,
+                       llvm::ArrayRef<mlir::OpFoldResult> steps) {
+  llvm::SmallVector<mlir::Value> normalizedIVs;
+
+  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
+    normalizedIVs.push_back(normalizeInductionVar(builder, iv, lb, step));
+  }
+
+  return normalizedIVs;
+}
+
 } // namespace concretelang
 } // namespace mlir
diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp
index 4f9942f3f1..74d24dd036 100644
--- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp
@@ -397,15 +397,13 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
       mlir::concretelang::GenericTypeConverterPattern<
           mlir::concretelang::Tracing::TraceCiphertextOp>,
       mlir::concretelang::GenericTypeConverterPattern,
-      mlir::concretelang::GenericTypeConverterPattern>(
-      &getContext(), converter);
+      mlir::concretelang::GenericTypeConverterPattern,
+      mlir::concretelang::GenericTypeConverterPattern<
+          mlir::tensor::ParallelInsertSliceOp>>(&getContext(), converter);
 
   mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
                                                           converter);
 
-  mlir::concretelang::GenericTypeConverterPattern<
-      mlir::tensor::ParallelInsertSliceOp>(&getContext(), converter);
-
   // Apply conversion
   if (mlir::applyPartialConversion(op, target, std::move(patterns))
           .failed()) {
diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp
index f8288d7c53..f6c59b3c13 100644
--- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp
@@ -377,6 +377,11 @@ void TFHEKeyNormalizationPass::runOnOperation() {
   mlir::concretelang::addDynamicallyLegalTypeOp(
       target, typeConverter);
 
+  patterns.add<mlir::concretelang::GenericTypeConverterPattern<
+      mlir::tensor::ParallelInsertSliceOp>>(&getContext(), typeConverter);
+  mlir::concretelang::addDynamicallyLegalTypeOp<
+      mlir::tensor::ParallelInsertSliceOp>(target, typeConverter);
+
   patterns.add>(
       &getContext(), typeConverter);
 
diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt
index 993e05688d..5168d47ea3 100644
--- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt
+++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt
@@ -1,9 +1,11 @@
 add_mlir_dialect_library(
   RTDialectTransforms
   BufferizableOpInterfaceImpl.cpp
+  HoistAwaitFuturePass.cpp
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/RT
   DEPENDS
+  RTTransformsIncGen
   mlir-headers
   LINK_LIBS
   PUBLIC
diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp
new file mode 100644
index 0000000000..f5e9784054
--- /dev/null
+++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp
@@ -0,0 +1,259 @@
+// Part of the Concrete Compiler Project, under the BSD3 License with Zama
+// Exceptions. See
+// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
+// for license information.
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace {
+struct HoistAwaitFuturePass
+    : public HoistAwaitFuturePassBase<HoistAwaitFuturePass> {
+  // Checks if all values of `a` are sizes of non-dynamic dimensions
+  bool allStatic(llvm::ArrayRef<int64_t> a) {
+    return llvm::all_of(
+        a, [](int64_t r) { return !mlir::ShapedType::isDynamic(r); });
+  }
+
+  void runOnOperation() override {
+    mlir::func::FuncOp func = getOperation();
+
+    llvm::SmallVector<mlir::Operation *> opsToErase;
+
+    func.walk([&](mlir::concretelang::RT::AwaitFutureOp awaitFutureOp) {
+      // Make sure there are no other consumers that rely on the
+      // synchronization
+      if (!awaitFutureOp.getResult().hasOneUse())
+        return;
+
+      mlir::scf::ForallOp forallOp =
+          llvm::dyn_cast<mlir::scf::ForallOp>(awaitFutureOp->getParentOp());
+
+      if (!forallOp)
+        return;
+
+      mlir::tensor::ParallelInsertSliceOp parallelInsertSliceOp =
+          llvm::dyn_cast<mlir::tensor::ParallelInsertSliceOp>(
+              awaitFutureOp.getResult().getUses().begin()->getOwner());
+
+      if (!parallelInsertSliceOp)
+        return;
+
+      // Make sure that the original tensor into which the
+      // synchronized values are inserted is a region out argument of
+      // the forall op and is thus written to concurrently
+      mlir::Value dst = parallelInsertSliceOp.getDest();
+
+      if (!llvm::any_of(forallOp.getRegionOutArgs(),
+                        [=](mlir::Value output) { return output == dst; }))
+        return;
+
+      // Currently, the tensor storing the futures must have a static
+      // shape, so only loops with static trip counts are supported
+      if (!(allStatic(forallOp.getStaticLowerBound()) &&
+            allStatic(forallOp.getStaticUpperBound()) &&
+            allStatic(forallOp.getStaticStep())))
+        return;
+
+      llvm::SmallVector<int64_t> tripCounts;
+
+      for (auto [lb, ub, step] : llvm::zip_equal(forallOp.getStaticLowerBound(),
+                                                 forallOp.getStaticUpperBound(),
+                                                 forallOp.getStaticStep())) {
+        tripCounts.push_back(
+            mlir::concretelang::getStaticTripCount(lb, ub, step));
+      }
+
+      mlir::IRRewriter rewriter(&getContext());
+      rewriter.setInsertionPoint(forallOp);
+
+      mlir::Value tensorOfFutures = rewriter.create<mlir::tensor::EmptyOp>(
+          forallOp.getLoc(), tripCounts, awaitFutureOp.getInput().getType());
+
+      // Assemble the list of shared outputs that are to be preserved
+      // after the output storing the results of the `RT.await_future`
+      // has been removed
+      llvm::SmallVector<mlir::Value> newOutputs;
+      mlir::Value tensorOfValues;
+      size_t i = 0;
+      size_t oldResultIdx;
+      for (auto [output, regionOutArg] : llvm::zip_equal(
+               forallOp.getOutputs(), forallOp.getRegionOutArgs())) {
+        if (regionOutArg != dst) {
+          newOutputs.push_back(output);
+        } else {
+          tensorOfValues = output;
+          oldResultIdx = i;
+        }
+
+        i++;
+      }
+
+      newOutputs.push_back(tensorOfFutures);
+
+      // Create a new forall loop with the same shared outputs, except
+      // that the one previously storing the contents of the
+      // `RT.await_future` ops is replaced with a tensor of futures
+      rewriter.setInsertionPointAfter(forallOp);
+      mlir::scf::ForallOp newForallOp = rewriter.create<mlir::scf::ForallOp>(
+          forallOp.getLoc(), forallOp.getMixedLowerBound(),
+          forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOutputs,
+          std::nullopt);
+
+      // Move all operations from the old forall op to the new one
+      auto &newOperations = newForallOp.getBody()->getOperations();
+      mlir::Block *oldBody = forallOp.getBody();
+
+      newOperations.splice(newOperations.begin(), oldBody->getOperations(),
+                           oldBody->begin(), std::prev(oldBody->end()));
+
+      // Wrap the future in a tensor of one element, so that it can be
+      // stored in the new shared output tensor of futures using
+      // `tensor.parallel_insert_slice`
+      rewriter.setInsertionPointAfter(awaitFutureOp);
+      mlir::Value futureAsTensor =
+          rewriter.create<mlir::tensor::FromElementsOp>(
+              awaitFutureOp.getLoc(),
+              mlir::ValueRange{awaitFutureOp.getInput()});
+
+      // Move all operations from the old `scf.forall.in_parallel`
+      // terminator to the new one
+      mlir::scf::InParallelOp oldTerminator = forallOp.getTerminator();
+      mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator();
+
+      mlir::Block::OpListType &oldTerminatorOps =
+          oldTerminator.getRegion().getBlocks().begin()->getOperations();
+      mlir::Block::OpListType &newTerminatorOps =
+          newTerminator.getRegion().getBlocks().begin()->getOperations();
+
+      newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps,
+                              oldTerminatorOps.begin(), oldTerminatorOps.end());
+
+      // Remap IVs and out args
+      for (auto [oldIV, newIV] : llvm::zip(forallOp.getInductionVars(),
+                                           newForallOp.getInductionVars())) {
+        oldIV.replaceAllUsesWith(newIV);
+      }
+
+      {
+        size_t offs = 0;
+        for (auto it : llvm::enumerate(forallOp.getRegionOutArgs())) {
+          mlir::Value oldRegionOutArg = it.value();
+
+          if (oldRegionOutArg != dst) {
+            oldRegionOutArg.replaceAllUsesWith(
+                newForallOp.getRegionOutArgs()[it.index() - offs]);
+          } else {
+            offs++;
+          }
+        }
+      }
+
+      // Create a new `tensor.parallel_insert_slice` operation inserting
+      // the future into the tensor of futures
+      llvm::SmallVector<mlir::OpFoldResult> ones(tripCounts.size(),
+                                                 rewriter.getI64IntegerAttr(1));
+
+      mlir::Value tensorOfFuturesRegionOutArg =
+          newForallOp.getRegionOutArgs().back();
+
+      mlir::ImplicitLocOpBuilder ilob(parallelInsertSliceOp.getLoc(), rewriter);
+
+      rewriter.setInsertionPointAfter(parallelInsertSliceOp);
+      rewriter.create<mlir::tensor::ParallelInsertSliceOp>(
+          parallelInsertSliceOp.getLoc(), futureAsTensor,
+          tensorOfFuturesRegionOutArg,
+          mlir::getAsOpFoldResult(mlir::concretelang::normalizeInductionVars(
+              ilob, newForallOp.getInductionVars(),
+              newForallOp.getMixedLowerBound(), newForallOp.getMixedStep())),
+          ones, ones);
+
+      // Create a new forall loop that invokes `RT.await_future` on
+      // all futures stored in the tensor of futures and writes the
+      // contents into the original tensor with the results
+      rewriter.setInsertionPointAfter(newForallOp);
+      mlir::scf::ForallOp syncForallOp = rewriter.create<mlir::scf::ForallOp>(
+          forallOp.getLoc(), forallOp.getMixedLowerBound(),
+          forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
+          mlir::ValueRange{tensorOfValues}, std::nullopt);
+
+      mlir::Value resultTensorOfFutures = newForallOp.getResults().back();
+
+      rewriter.setInsertionPointToStart(syncForallOp.getBody());
+      mlir::Value extractedFuture = rewriter.create<mlir::tensor::ExtractOp>(
+          awaitFutureOp.getLoc(), resultTensorOfFutures,
+          syncForallOp.getInductionVars());
+
+      mlir::concretelang::RT::AwaitFutureOp newAwaitFutureOp =
+          rewriter.create<mlir::concretelang::RT::AwaitFutureOp>(
+              awaitFutureOp.getLoc(), awaitFutureOp.getResult().getType(),
+              extractedFuture);
+
+      mlir::IRMapping syncMapping;
+
+      for (auto [oldIV, newIV] :
+           llvm::zip_equal(newForallOp.getInductionVars(),
+                           syncForallOp.getInductionVars())) {
+        syncMapping.map(oldIV, newIV);
+      }
+
+      syncMapping.map(dst, syncForallOp.getOutputBlockArguments().back());
+      syncMapping.map(parallelInsertSliceOp.getSource(),
+                      newAwaitFutureOp.getResult());
+
+      mlir::scf::InParallelOp syncTerminator = syncForallOp.getTerminator();
+      rewriter.setInsertionPointToStart(syncTerminator.getBody());
+      rewriter.clone(*parallelInsertSliceOp.getOperation(), syncMapping);
+
+      // Replace uses of the results of the original forall loop with
+      // either the corresponding result from the new forall loop, if
+      // this is a result unrelated to the futures, or with the result
+      // of the forall loop synchronizing the futures
+      {
+        size_t offs = 0;
+        for (size_t i = 0; i < forallOp.getNumResults(); i++) {
+          if (i == oldResultIdx) {
+            forallOp.getResult(i).replaceAllUsesWith(syncForallOp.getResult(0));
+            offs = 1;
+          } else {
+            forallOp.getResult(i).replaceAllUsesWith(
+                newForallOp.getResult(i - offs));
+          }
+        }
+      }
+
+      // Replace the use of the shared output with the results of the
+      // original forall loop with the tensor outside of the loop, so
+      // that there are no more references to values that were local
+      // to the original forall loop, enabling safe erasing of the old
+      // operations within the original forall loop
+      dst.replaceAllUsesWith(
+          forallOp.getOutputs().drop_front(oldResultIdx).front());
+      parallelInsertSliceOp->erase();
+      awaitFutureOp.erase();
+
+      // Defer erasing the original parallel loop that contained the
+      // `RT.await_future` operation until later in order to not
+      // confuse the walk relying on the parent operation
+      opsToErase.push_back(forallOp);
+    });
+
+    for (mlir::Operation *op : opsToErase)
+      op->erase();
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace concretelang {
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+createHoistAwaitFuturePass() {
+  return std::make_unique<HoistAwaitFuturePass>();
+}
+} // namespace concretelang
+} // namespace mlir
diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
index c1f6187b1a..210d2ea03d 100644
--- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
@@ -44,6 +44,7 @@
 #include "concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h"
 #include "concretelang/Dialect/FHELinalg/Transforms/Tiling.h"
 #include "concretelang/Dialect/RT/Analysis/Autopar.h"
+#include "concretelang/Dialect/RT/Transforms/Passes.h"
 #include "concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h"
 #include "concretelang/Dialect/TFHE/Transforms/Transforms.h"
 #include "concretelang/Support/CompilerEngine.h"
@@ -183,6 +184,8 @@ mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
       pm, mlir::concretelang::createBuildDataflowTaskGraphPass(), enablePass);
   addPotentiallyNestedPass(
       pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass);
+  addPotentiallyNestedPass(pm, mlir::concretelang::createHoistAwaitFuturePass(),
+                           enablePass);
 
   return pm.run(module.getOperation());
 }
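
For reference, the slot arithmetic behind the hoisting above is plain loop normalization: the tensor of futures gets one element per iteration of the `scf.forall`, and each iteration writes its future at its normalized position. The sketch below is only an illustration of the semantics assumed here for `getStaticTripCount` and `normalizeInductionVar` (ceiling division for the trip count, `(iv - lb) / step` for the slot index); it is not part of the patches.

```
// Standalone illustration of the assumed index arithmetic: one future slot
// per loop iteration, addressed by the normalized induction variable.
#include <cstdint>
#include <iostream>

// Number of iterations of a loop `for (iv = lb; iv < ub; iv += step)`.
int64_t staticTripCount(int64_t lb, int64_t ub, int64_t step) {
  return (ub - lb + step - 1) / step; // ceil((ub - lb) / step)
}

// Position of iteration `iv` in [0, tripCount), i.e. the slot in the
// tensor of futures written by this iteration with unit size and stride.
int64_t normalizeInductionVar(int64_t iv, int64_t lb, int64_t step) {
  return (iv - lb) / step;
}

int main() {
  // E.g. tiling a size-7 reduction dimension with tile size 2, as in the
  // check test of PATCH 17/17 below, yields ceil(7 / 2) = 4 slots.
  int64_t lb = 0, ub = 7, step = 2;
  std::cout << "trip count: " << staticTripCount(lb, ub, step) << "\n";
  for (int64_t iv = lb; iv < ub; iv += step)
    std::cout << "iv " << iv << " -> slot "
              << normalizeInductionVar(iv, lb, step) << "\n";
  return 0;
}
```

This is also why the pass bails out on loops with dynamic bounds: without static trip counts, the shape of the tensor of futures could not be fixed up front.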

From 7cf5483425f8b02337ee349311137fd987307417 Mon Sep 17 00:00:00 2001
From: Andi Drebes
Date: Thu, 4 Apr 2024 11:54:37 +0200
Subject: [PATCH 16/17] test(compiler): Add tests for tiling generating partial
 tiles

---
 .../tests_cpu/end_to_end_fhelinalg.yaml       | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml
index 005efce9fd..8f448c6426 100644
--- a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml
+++ b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml
@@ -2152,6 +2152,24 @@ tests:
     - tensor: [16,21,44,57,12,23,30,39,58,55,16,21,44,57,12,23]
       shape: [8,2]
 
+---
+description: tiled_matmul_eint_int_1_1_7_partial_tiles
+program: |
+  func.func @main(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8x2x!FHE.eint<6>> {
+    %0 = "FHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [0,0,7] } : (tensor<8x4x!FHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!FHE.eint<6>>
+    return %0 : tensor<8x2x!FHE.eint<6>>
+  }
+tests:
+  - inputs:
+    - tensor: [1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2]
+      shape: [8,4]
+    - tensor: [1,2,3,4,3,1,0,2]
+      shape: [4,2]
+      width: 8
+    outputs:
+    - tensor: [16,21,44,57,12,23,30,39,58,55,16,21,44,57,12,23]
+      shape: [8,2]
+
 ---
 description: extract_slice_parametric_2x2
 program: |

From f506f5f7e3b1cb5f3b3880df35c3133bb4f28ae3 Mon Sep 17 00:00:00 2001
From: Andi Drebes
Date: Mon, 8 Apr 2024 15:41:04 +0200
Subject: [PATCH 17/17] test(compiler): Add check tests for pass hoisting
 RT.await_future operations

---
 .../Dialect/RT/hoist_await_future.mlir        | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)
 create mode 100644 compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future.mlir

diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future.mlir
new file mode 100644
index 0000000000..e6748d3224
--- /dev/null
+++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future.mlir
@@ -0,0 +1,20 @@
+// RUN: concretecompiler --action=dump-fhe-df-parallelized %s --optimizer-strategy=dag-mono --parallelize | FileCheck %s
+// RUN: concretecompiler --action=dump-llvm-ir %s --optimizer-strategy=dag-mono --parallelize
+// RUN: concretecompiler --action=dump-llvm-ir %s --optimizer-strategy=dag-multi --parallelize
+
+// CHECK: scf.forall.in_parallel {
+// CHECK-NEXT: tensor.parallel_insert_slice %from_elements into %arg3[%arg2] [1] [1] : tensor<1x!RT.future<tensor<8x9x!FHE.eint<6>>>> into tensor<4x!RT.future<tensor<8x9x!FHE.eint<6>>>>
+// CHECK-NEXT: }
+//
+// CHECK: %[[res:.*]] = scf.forall (%[[arg:.*]]) in (4) shared_outs(%[[so:.*]] = %[[init:.*]]) -> (tensor<8x9x4x!FHE.eint<6>>) {
+// CHECK-NEXT: %[[extracted:.*]] = tensor.extract %4[%[[arg]]] : tensor<4x!RT.future<tensor<8x9x!FHE.eint<6>>>>
+// CHECK-NEXT: %[[awaitres:.*]] = "RT.await_future"(%[[extracted]]) : (!RT.future<tensor<8x9x!FHE.eint<6>>>) -> tensor<8x9x!FHE.eint<6>>
+// CHECK-NEXT: scf.forall.in_parallel {
+// CHECK-NEXT: tensor.parallel_insert_slice %[[awaitres]] into %[[so]][0, 0, %[[arg]]] [8, 9, 1] [1, 1, 1] : tensor<8x9x!FHE.eint<6>> into tensor<8x9x4x!FHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+func.func @main(%a: tensor<8x7x!FHE.eint<6>>, %b: tensor<7x9xi7>) -> tensor<8x9x!FHE.eint<6>>{
+  %0 = "FHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [0, 0, 2] } : (tensor<8x7x!FHE.eint<6>>, tensor<7x9xi7>) -> tensor<8x9x!FHE.eint<6>>
+  return %0 : tensor<8x9x!FHE.eint<6>>
+}
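
As a usage note, the hoisting pass is scheduled by `autopar()` in `Pipeline.cpp` (see the patch above), but it can also be run on its own through a standard MLIR pass manager. The following is a minimal sketch, assuming only the `createHoistAwaitFuturePass()` entry point and the headers introduced above; the include paths mirror the patch and may need adjusting to the local build setup.

```
// Minimal sketch: run only the hoisting pass on an already-parsed module.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

#include "concretelang/Dialect/RT/Transforms/Passes.h"

// Returns true on success. The pass is anchored on `mlir::func::FuncOp`,
// so it is nested under the module-level pass manager.
bool hoistAwaitFutures(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::concretelang::createHoistAwaitFuturePass());
  return mlir::succeeded(pm.run(module.getOperation()));
}
```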