diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h index 559e7136b9..829222b0b8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h @@ -7,6 +7,7 @@ #define CONCRETELANG_ANALYSIS_STATIC_LOOPS_H #include +#include namespace mlir { namespace concretelang { @@ -60,6 +61,16 @@ bool isConstantIndexValue(mlir::Value v); int64_t getConstantIndexValue(mlir::Value v); bool isConstantIndexValue(mlir::Value v, int64_t i); +mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder, + mlir::Value iv, mlir::OpFoldResult lb, + mlir::OpFoldResult step); + +llvm::SmallVector +normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder, + mlir::ValueRange ivs, + llvm::ArrayRef lbs, + llvm::ArrayRef steps); + } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/RTOpConverter.h b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/RTOpConverter.h index 1d25b1d384..559b28bab6 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/RTOpConverter.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/RTOpConverter.h @@ -6,8 +6,8 @@ #ifndef CONCRETELANG_CONVERSION_RTOPCONVERTER_H_ #define CONCRETELANG_CONVERSION_RTOPCONVERTER_H_ -#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" #include "concretelang/Conversion/Utils/Legality.h" +#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -20,25 +20,25 @@ populateWithRTTypeConverterPatterns(mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, mlir::TypeConverter &converter) { patterns.add< - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DataflowTaskOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DataflowYieldOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::MakeReadyFutureOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::AwaitFutureOp>, - mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::CreateAsyncTaskOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::concretelang::RT::CreateAsyncTaskOp, true>, + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, - mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::WorkFunctionReturnOp>, - 
mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>( patterns.getContext(), converter); diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h index dc95b47767..b7b7d4c9b0 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h @@ -22,6 +22,12 @@ mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter, mlir::Location loc, mlir::ArrayAttr arrAttr); +mlir::Operation *convertOpWithBlocks(mlir::Operation *op, + mlir::ValueRange newOperands, + mlir::TypeRange newResultTypes, + mlir::TypeConverter &typeConverter, + mlir::ConversionPatternRewriter &rewriter); + } // namespace concretelang } // namespace mlir #endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt index 4f74948933..306b439685 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Analysis) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/IR/RTTypes.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/IR/RTTypes.td index bc0ebaea93..687a318f56 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/IR/RTTypes.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/IR/RTTypes.td @@ -7,7 +7,7 @@ include "mlir/IR/BuiltinTypes.td" class RT_Type traits = []> : TypeDef { } -def RT_Future : RT_Type<"Future"> { +def RT_Future : RT_Type<"Future", [MemRefElementTypeInterface]> { let mnemonic = "future"; let summary = "Future with a parameterized element type"; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..0661dc36df --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name RT) +add_public_tablegen_target(RTTransformsIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h new file mode 100644 index 0000000000..371a7544c2 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h @@ -0,0 +1,25 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. 
+ +#ifndef CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H +#define CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" + +#include "concretelang/Dialect/RT/IR/RTDialect.h" + +#define GEN_PASS_CLASSES +#include "concretelang/Dialect/RT/Transforms/Passes.h.inc" + +namespace mlir { +namespace concretelang { +std::unique_ptr> createHoistAwaitFuturePass(); +} // namespace concretelang +} // namespace mlir + +#endif // CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td new file mode 100644 index 0000000000..f638400d58 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td @@ -0,0 +1,82 @@ +#ifndef MLIR_DIALECT_RT_TRANSFORMS_PASSES +#define MLIR_DIALECT_RT_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def HoistAwaitFuturePass : Pass<"hoist-await-future", "mlir::func::FuncOp"> { + let summary = "Hoists `RT.await_future` operations whose results are yielded " + "by `scf.forall` operations out of the loops"; + let description = [{ + Hoists `RT.await_future` operations whose results are yielded by + scf.forall operations out of the loops in order to avoid + over-synchronization of data-flow tasks. + + E.g., the following IR: + + ``` + scf.forall (%arg) in (16) + shared_outs(%o1 = %sometensor, %o2 = %someothertensor) + -> (tensor<...>, tensor<...>) + { + ... + %rph = "RT.build_return_ptr_placeholder"() : + () -> !RT.rtptr>> + "RT.create_async_task"(..., %rph, ...) { ... } : ... + %future = "RT.deref_return_ptr_placeholder"(%rph) : + (!RT.rtptr>) -> !RT.future> + %res = "RT.await_future"(%future) : (!RT.future>) -> tensor<...> + ... + scf.forall.in_parallel { + ... + tensor.parallel_insert_slice %res into %o1[..., %arg, ...] [...] [...] : + tensor<...> into tensor<...> + ... + } + } + ``` + + is transformed into: + + ``` + %tensoroffutures = tensor.empty() : tensor<16x!RT.future>> + + scf.forall (%arg) in (16) + shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor) + -> (tensor<...>, tensor<...>) + { + ... + %rph = "RT.build_return_ptr_placeholder"() : + () -> !RT.rtptr>> + "RT.create_async_task"(..., %rph, ...) { ... } : ... + %future = "RT.deref_return_ptr_placeholder"(%rph) : + (!RT.rtptr>) -> !RT.future> + %wrappedfuture = tensor.from_elements %future : + tensor<1x!RT.future>> + ... + scf.forall.in_parallel { + ... + tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] : + tensor<1x!RT.future>> into tensor<16x!RT.future>> + ... + } + } + + scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) { + %future = tensor.extract %tensoroffutures[%arg] : + tensor<16x!RT.future>> + %res = "RT.await_future"(%future) : (!RT.future>) -> tensor<...> + scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...]
: + tensor<...> into tensor<...> + } + } + ``` + }]; + let constructor = "mlir::concretelang::createHoistAwaitFuturePass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", "mlir::scf::SCFDialect", + "mlir::tensor::TensorDialect", "mlir::concretelang::RT::RTDialect" + ]; +} + +#endif // MLIR_DIALECT_RT_TRANSFORMS_PASSES diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h index 68283185ec..e1fd0772ee 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h @@ -252,6 +252,10 @@ class CompilerEngine { /// from the Linalg dialect FHE_LINALG_GENERIC, + /// Read sources and lower all the FHELinalg operations to FHE + /// operations, dump after data-flow parallelization + FHE_DF_PARALLELIZED, + /// Read sources and lower all the FHELinalg operations to FHE operations /// and scf loops FHE_NO_LINALG, diff --git a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp index 4a22a588cd..2515dbf081 100644 --- a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp +++ b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp @@ -5,6 +5,8 @@ #include #include +#include +#include #include #include @@ -446,5 +448,43 @@ bool isConstantIndexValue(mlir::Value v, int64_t i) { return isConstantIndexValue(v) && getConstantIndexValue(v) == i; } +// Returns a `Value` corresponding to `iv`, normalized to the lower +// bound `lb` and step `step` of a loop (i.e., (iv - lb) / step). +mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder, + mlir::Value iv, mlir::OpFoldResult lb, + mlir::OpFoldResult step) { + std::optional lbInt = mlir::getConstantIntValue(lb); + std::optional stepInt = mlir::getConstantIntValue(step); + + mlir::Value idxShifted = lbInt.has_value() && *lbInt == 0 + ? iv + : builder.create( + iv, mlir::getValueOrCreateConstantIndexOp( + builder, builder.getLoc(), lb)); + + mlir::Value normalizedIV = + stepInt.has_value() && *stepInt == 1 + ? 
idxShifted + : builder.create( + idxShifted, mlir::getValueOrCreateConstantIndexOp( + builder, builder.getLoc(), step)); + + return normalizedIV; +} + +llvm::SmallVector +normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder, + mlir::ValueRange ivs, + llvm::ArrayRef lbs, + llvm::ArrayRef steps) { + llvm::SmallVector normalizedIVs; + + for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) { + normalizedIVs.push_back(normalizeInductionVar(builder, iv, lb, step)); + } + + return normalizedIVs; +} + } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 3dc83a8ca2..1fd921355d 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -697,6 +697,9 @@ struct FHELinalgApplyLookupTableToLinalgGeneric outs, maps, iteratorTypes, doc, call, bodyBuilder); + if (lutOp->hasAttr("tile-sizes")) + genericOp->setAttr("tile-sizes", lutOp->getAttr("tile-sizes")); + rewriter.replaceOp(lutOp, {genericOp.getResult(0)}); return ::mlir::success(); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index a816faeaa8..564daf4451 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -84,29 +84,6 @@ TFHE::GLWECipherTextType convertEncrypted(mlir::MLIRContext *context, return TFHE::GLWECipherTextType::get(context, TFHE::GLWESecretKey()); } -/// Converts `Tensor` into a -/// `Tensor` if the element type is appropriate. -/// Otherwise return the input type. -mlir::Type -maybeConvertEncryptedTensor(mlir::MLIRContext *context, - mlir::RankedTensorType maybeEncryptedTensor) { - if (!maybeEncryptedTensor.getElementType().isa()) { - return (mlir::Type)(maybeEncryptedTensor); - } - auto currentShape = maybeEncryptedTensor.getShape(); - return mlir::RankedTensorType::get( - currentShape, - TFHE::GLWECipherTextType::get(context, TFHE::GLWESecretKey())); -} - -/// Converts any encrypted type to `TFHE::GlweCiphetext` if the -/// input type is appropriate. Otherwise return the input type. -mlir::Type maybeConvertEncrypted(mlir::MLIRContext *context, mlir::Type t) { - if (auto eint = t.dyn_cast()) - return convertEncrypted(context, eint); - return t; -} - /// The type converter used to convert `FHE` to `TFHE` types using the scalar /// strategy. 
class TypeConverter : public mlir::TypeConverter { @@ -114,6 +91,21 @@ class TypeConverter : public mlir::TypeConverter { public: TypeConverter() { addConversion([](mlir::Type type) { return type; }); + addConversion([&](mlir::FunctionType type) { + llvm::SmallVector inputTypes; + llvm::SmallVector resultTypes; + + for (mlir::Type inputType : type.getInputs()) + inputTypes.push_back(this->convertType(inputType)); + + for (mlir::Type resultType : type.getResults()) + resultTypes.push_back(this->convertType(resultType)); + + mlir::Type res = + mlir::FunctionType::get(type.getContext(), inputTypes, resultTypes); + + return res; + }); addConversion([](FHE::FheIntegerInterface type) { return convertEncrypted(type.getContext(), type); }); @@ -121,8 +113,9 @@ class TypeConverter : public mlir::TypeConverter { return TFHE::GLWECipherTextType::get(type.getContext(), TFHE::GLWESecretKey()); }); - addConversion([](mlir::RankedTensorType type) { - return maybeConvertEncryptedTensor(type.getContext(), type); + addConversion([&](mlir::RankedTensorType type) { + return mlir::RankedTensorType::get( + type.getShape(), this->convertType(type.getElementType())); }); addConversion([&](mlir::concretelang::RT::FutureType type) { return mlir::concretelang::RT::FutureType::get( @@ -860,6 +853,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { mlir::scf::ForallOp>, mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::tensor::EmptyOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::FromElementsOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::DimOp>, mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::tensor::ParallelInsertSliceOp, true>>(&getContext(), converter); @@ -882,6 +879,8 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { target, converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::tensor::ParallelInsertSliceOp>(target, converter); @@ -917,6 +916,9 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 074075b8f3..74d24dd036 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -53,13 +53,8 @@ class TFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter { } }); addConversion([&](mlir::RankedTensorType type) { - auto glwe = type.getElementType().dyn_cast_or_null(); - if (glwe == nullptr || !glwe.getKey().isNone()) { - return (mlir::Type)(type); - } - mlir::Type r = mlir::RankedTensorType::get(type.getShape(), - this->glweInterPBSType(glwe)); - return r; + return mlir::RankedTensorType::get( + type.getShape(), this->convertType(type.getElementType())); }); addConversion([&](mlir::concretelang::RT::FutureType type) { return mlir::concretelang::RT::FutureType::get( @@ -353,6 +348,17 @@ void 
TFHEGlobalParametrizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + &getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + + patterns.add>(&getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + patterns.add>( &getContext(), converter); @@ -391,15 +397,13 @@ void TFHEGlobalParametrizationPass::runOnOperation() { mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::Tracing::TraceCiphertextOp>, mlir::concretelang::GenericTypeConverterPattern, - mlir::concretelang::GenericTypeConverterPattern>( - &getContext(), converter); + mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern< + mlir::tensor::ParallelInsertSliceOp>>(&getContext(), converter); mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); - mlir::concretelang::GenericTypeConverterPattern< - mlir::tensor::ParallelInsertSliceOp>(&getContext(), converter); - // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) .failed()) { diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp index 8b739718d8..f6c59b3c13 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp @@ -371,6 +371,17 @@ void TFHEKeyNormalizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp( target, typeConverter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + &getContext(), typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, typeConverter); + + patterns.add>(&getContext(), typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::tensor::ParallelInsertSliceOp>(target, typeConverter); + patterns.add>( &getContext(), typeConverter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index 94677739c9..4b35e8ef49 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -59,7 +59,9 @@ class TFHEToConcreteTypeConverter : public mlir::TypeConverter { addConversion([&](mlir::RankedTensorType type) { auto glwe = type.getElementType().dyn_cast_or_null(); if (glwe == nullptr) { - return (mlir::Type)(type); + return mlir::RankedTensorType::get( + type.getShape(), this->convertType(type.getElementType())) + .cast(); } mlir::SmallVector newShape; newShape.reserve(type.getShape().size() + 1); @@ -458,9 +460,21 @@ struct ExtractOpPattern return mlir::failure(); } - auto newResultType = this->getTypeConverter() - ->convertType(extractOp.getType()) - .cast(); + auto newResultType = + this->getTypeConverter()->convertType(extractOp.getType()); + + // If the extraction is not on a tensor of ciphertexts, just + // convert the type and keep the rest as-is. 
+ if (!extractOp.getType().isa()) { + rewriter.replaceOpWithNewOp( + extractOp, newResultType, adaptor.getTensor(), adaptor.getIndices()); + + return mlir::success(); + } + + mlir::RankedTensorType newResultTensorType = + newResultType.cast(); + auto tensorRank = adaptor.getTensor().getType().cast().getRank(); @@ -473,14 +487,15 @@ struct ExtractOpPattern // [1..., nbBlock, lweDimension+1] mlir::SmallVector staticSizes(tensorRank, 1); staticSizes[staticSizes.size() - 1] = - newResultType.getDimSize(newResultType.getRank() - 1); + newResultTensorType.getDimSize(newResultTensorType.getRank() - 1); // [1...] for static_strides mlir::SmallVector staticStrides(tensorRank, 1); rewriter.replaceOpWithNewOp( - extractOp, newResultType, adaptor.getTensor(), adaptor.getIndices(), - mlir::SmallVector{}, mlir::SmallVector{}, + extractOp, newResultTensorType, adaptor.getTensor(), + adaptor.getIndices(), mlir::SmallVector{}, + mlir::SmallVector{}, rewriter.getDenseI64ArrayAttr(staticOffsets), rewriter.getDenseI64ArrayAttr(staticSizes), rewriter.getDenseI64ArrayAttr(staticStrides)); @@ -503,24 +518,34 @@ struct InsertSliceOpPattern : public mlir::OpConversionPattern { ::mlir::LogicalResult matchAndRewrite(OpTy insertSliceOp, typename OpTy::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { + bool needsExtraDimension = insertSliceOp.getDest() + .getType() + .getElementType() + .template isa(); + mlir::RankedTensorType newDestTy = ((mlir::Type)adaptor.getDest().getType()) .cast(); - // add 0 to offsets mlir::SmallVector offsets = getMixedValues( adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter); - offsets.push_back(rewriter.getI64IntegerAttr(0)); - // add lweDimension+1 to sizes mlir::SmallVector sizes = getMixedValues(adaptor.getStaticSizes(), adaptor.getSizes(), rewriter); - sizes.push_back(rewriter.getI64IntegerAttr( - newDestTy.getDimSize(newDestTy.getRank() - 1))); - // add 1 to the strides mlir::SmallVector strides = getMixedValues( adaptor.getStaticStrides(), adaptor.getStrides(), rewriter); - strides.push_back(rewriter.getI64IntegerAttr(1)); + + if (needsExtraDimension) { + // add 0 to offsets + offsets.push_back(rewriter.getI64IntegerAttr(0)); + + // add lweDimension+1 to sizes + sizes.push_back(rewriter.getI64IntegerAttr( + newDestTy.getDimSize(newDestTy.getRank() - 1))); + + // add 1 to the strides + strides.push_back(rewriter.getI64IntegerAttr(1)); + } // replace insert slice-like operation with the new one rewriter.replaceOpWithNewOp(insertSliceOp, adaptor.getSource(), @@ -528,7 +553,7 @@ struct InsertSliceOpPattern : public mlir::OpConversionPattern { strides); return ::mlir::success(); - }; + } }; /// Pattern that rewrites the Insert operation, taking into account the @@ -599,13 +624,23 @@ struct FromElementsOpPattern matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp, mlir::tensor::FromElementsOp::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { + auto converter = this->getTypeConverter(); // is not a tensor of GLWEs that need to be extended with the LWE dimension - if (this->getTypeConverter()->isLegal(fromElementsOp.getType())) { + if (converter->isLegal(fromElementsOp.getType())) { return mlir::failure(); } - auto converter = this->getTypeConverter(); + // If the element type is not directly a cipher text type, the + // shape of the output does not change. In this case, the op type + // can be preserved and only type conversion is necessary. 
+ if (!fromElementsOp.getType().getElementType().isa()) { + rewriter.replaceOpWithNewOp( + fromElementsOp, converter->convertType(fromElementsOp.getType()), + adaptor.getOperands()); + + return mlir::success(); + } auto resultTy = fromElementsOp.getResult().getType(); if (converter->isLegal(resultTy)) { @@ -891,11 +926,11 @@ void TFHEToConcretePass::runOnOperation() { mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp, mlir::tensor::InsertSliceOp, mlir::tensor::ParallelInsertSliceOp, mlir::tensor::ExpandShapeOp, mlir::tensor::CollapseShapeOp, - mlir::tensor::EmptyOp, mlir::bufferization::AllocTensorOp>( - [&](mlir::Operation *op) { - return converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getOperandTypes()); - }); + mlir::tensor::EmptyOp, mlir::tensor::FromElementsOp, mlir::tensor::DimOp, + mlir::bufferization::AllocTensorOp>([&](mlir::Operation *op) { + return converter.isLegal(op->getResultTypes()) && + converter.isLegal(op->getOperandTypes()); + }); // rewrite scf for loops if working on illegal types patterns.add, mlir::concretelang::TypeConvertingReinstantiationPattern< - mlir::tensor::EmptyOp, true>>(&getContext(), converter); + mlir::tensor::EmptyOp, true>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::DimOp>>(&getContext(), converter); mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/SCF.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/SCF.cpp index 0f002782d3..033afd685d 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/SCF.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/SCF.cpp @@ -4,7 +4,7 @@ // for license information. 
#include "concretelang/Conversion/Utils/Dialects/SCF.h" -#include "mlir/Transforms/RegionUtils.h" +#include "concretelang/Conversion/Utils/Utils.h" namespace mlir { namespace concretelang { @@ -13,29 +13,16 @@ mlir::LogicalResult TypeConvertingReinstantiationPattern::matchAndRewrite( scf::ForOp oldOp, mlir::OpConversionPattern::OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - // Create new for loop with empty body, but converted iter args - scf::ForOp newForOp = rewriter.create( - oldOp.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), - adaptor.getStep(), adaptor.getInitArgs(), - [&](OpBuilder &builder, Location loc, Value iv, ValueRange args) {}); - - newForOp->setAttrs(adaptor.getAttributes()); - - // Move operations from old for op to new one - auto &newOperations = newForOp.getBody()->getOperations(); - mlir::Block *oldBody = oldOp.getBody(); - - newOperations.splice(newOperations.begin(), oldBody->getOperations(), - oldBody->begin(), oldBody->end()); + mlir::TypeConverter &typeConverter = *getTypeConverter(); + llvm::SmallVector convertedResultTypes; - // Remap iter args and IV - for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(), - newForOp.getBody()->getArguments())) { - replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair), - newForOp.getRegion()); + if (typeConverter.convertTypes(oldOp.getResultTypes(), convertedResultTypes) + .failed()) { + return mlir::failure(); } - rewriter.replaceOp(oldOp, newForOp.getResults()); + convertOpWithBlocks(oldOp, adaptor.getOperands(), convertedResultTypes, + typeConverter, rewriter); return mlir::success(); } @@ -49,56 +36,10 @@ TypeConvertingReinstantiationPattern::matchAndRewrite( scf::ForallOp oldOp, mlir::OpConversionPattern::OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - // Create new forall operation with empty body, but converted iter - // args - llvm::SmallVector lbs = getMixedValues( - adaptor.getStaticLowerBound(), adaptor.getDynamicLowerBound(), rewriter); - llvm::SmallVector ubs = getMixedValues( - adaptor.getStaticUpperBound(), adaptor.getDynamicUpperBound(), rewriter); - llvm::SmallVector step = getMixedValues( - adaptor.getStaticStep(), adaptor.getDynamicStep(), rewriter); - - rewriter.setInsertionPoint(oldOp); - - scf::ForallOp newForallOp = rewriter.create( - oldOp.getLoc(), lbs, ubs, step, adaptor.getOutputs(), - adaptor.getMapping()); - - newForallOp->setAttrs(adaptor.getAttributes()); - - // Move operations from old for op to new one - auto &newOperations = newForallOp.getBody()->getOperations(); - mlir::Block *oldBody = oldOp.getBody(); - - newOperations.splice(newOperations.begin(), oldBody->getOperations(), - oldBody->begin(), std::prev(oldBody->end())); - - // Move operations from `scf.forall.in_parallel` terminator of the - // old op to the terminator of the new op - - mlir::scf::InParallelOp oldTerminator = - llvm::dyn_cast(*std::prev(oldBody->end())); - - assert(oldTerminator && "Last operation of `scf.forall` op expected be a " - "`scf.forall.in_parallel` op"); - - mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator(); - - mlir::Block::OpListType &oldTerminatorOps = - oldTerminator.getRegion().getBlocks().begin()->getOperations(); - mlir::Block::OpListType &newTerminatorOps = - newTerminator.getRegion().getBlocks().begin()->getOperations(); - - newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps, - oldTerminatorOps.begin(), oldTerminatorOps.end()); - - // Remap iter args and IV - for (auto argsPair : 
llvm::zip(oldOp.getBody()->getArguments(), - newForallOp.getBody()->getArguments())) { - std::get<0>(argsPair).replaceAllUsesWith(std::get<1>(argsPair)); - } - rewriter.replaceOp(oldOp, newForallOp.getResults()); + convertOpWithBlocks(oldOp, adaptor.getOperands(), + adaptor.getOutputs().getTypes(), *getTypeConverter(), + rewriter); return mlir::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp index 52edb1b588..7e0670acbc 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp @@ -4,8 +4,7 @@ // for license information. #include "concretelang/Conversion/Utils/Dialects/Tensor.h" -#include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/STLExtras.h" +#include "concretelang/Conversion/Utils/Utils.h" namespace mlir { namespace concretelang { @@ -79,26 +78,8 @@ TypeConvertingReinstantiationPattern::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { mlir::SmallVector resultTypes = convertResultTypes(oldOp); - rewriter.setInsertionPointAfter(oldOp); - tensor::GenerateOp newGenerateOp = rewriter.create( - oldOp.getLoc(), resultTypes, adaptor.getOperands(), oldOp->getAttrs()); - - mlir::Block &oldBlock = oldOp.getBody().getBlocks().front(); - mlir::Block &newBlock = newGenerateOp.getBody().getBlocks().front(); - auto begin = oldBlock.begin(); - auto nOps = oldBlock.getOperations().size(); - - newBlock.getOperations().splice(newBlock.getOperations().begin(), - oldBlock.getOperations(), begin, - std::next(begin, nOps - 1)); - - for (auto argsPair : llvm::zip(oldOp.getRegion().getArguments(), - newGenerateOp.getRegion().getArguments())) { - replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair), - newGenerateOp.getRegion()); - } - - rewriter.replaceOp(oldOp, newGenerateOp.getResult()); + convertOpWithBlocks(oldOp, adaptor.getOperands(), resultTypes, + *getTypeConverter(), rewriter); return mlir::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp index 719a8e2b91..17d4c5904e 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp @@ -6,6 +6,7 @@ #include "concretelang/Conversion/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" +#include "mlir/Transforms/RegionUtils.h" namespace mlir { namespace concretelang { @@ -56,5 +57,32 @@ mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter, return mlir::concretelang::getCastedMemRef(rewriter, globalRef); } +// Converts an operation `op` with nested blocks using a type +// converter and a conversion pattern rewriter, such that the newly +// created operation uses the operands specified in `newOperands` and +// returns a value of the types `newResultTypes`. 
+mlir::Operation * +convertOpWithBlocks(mlir::Operation *op, mlir::ValueRange newOperands, + mlir::TypeRange newResultTypes, + mlir::TypeConverter &typeConverter, + mlir::ConversionPatternRewriter &rewriter) { + mlir::OperationState state(op->getLoc(), op->getName().getStringRef(), + newOperands, newResultTypes, op->getAttrs(), + op->getSuccessors()); + + for (Region ®ion : op->getRegions()) { + Region *newRegion = state.addRegion(); + rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); + TypeConverter::SignatureConversion result(newRegion->getNumArguments()); + (void)typeConverter.convertSignatureArgs(newRegion->getArgumentTypes(), + result); + rewriter.applySignatureConversion(newRegion, result); + } + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + + return newOp; +} } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp index 4e5d0bf137..b2bf4a814f 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -17,15 +18,28 @@ using namespace mlir::concretelang; using namespace mlir; namespace { +outcome::checked +getBufferSize(mlir::MemRefType bufferType); -int64_t getElementTypeSize(mlir::Type elementType) { +outcome::checked +getElementTypeSize(mlir::Type elementType) { if (auto integerType = mlir::dyn_cast(elementType)) { auto width = integerType.getWidth(); return std::ceil((double)width / 8); - } - if (mlir::dyn_cast(elementType)) { + } else if (mlir::dyn_cast(elementType)) { return 8; + } else if (auto memrefType = mlir::dyn_cast(elementType)) { + return getBufferSize(memrefType); + } else if (auto futureType = mlir::dyn_cast(elementType)) { + auto recSize = getElementTypeSize(futureType.getElementType()); + + if (!recSize) + return recSize.error(); + + // FIXME: Hardcoded size of 8 bytes for a pointer + return 8 * recSize.value(); } + return -1; } @@ -33,7 +47,13 @@ outcome::checked getBufferSize(mlir::MemRefType bufferType) { auto shape = bufferType.getShape(); auto elementType = bufferType.getElementType(); - auto elementSize = getElementTypeSize(elementType); + auto maybeElementSize = getElementTypeSize(elementType); + + if (!maybeElementSize) + return maybeElementSize.error(); + + auto elementSize = maybeElementSize.value(); + if (elementSize == -1) return StringError( "allocation of buffer with a non-supported element-type"); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp index 67cf341388..f2a327dc2b 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp @@ -185,8 +185,12 @@ class FHELinalgTilingMarkerPass mlir::ArrayAttr tileAttr = mlir::Builder(&this->getContext()).getI64ArrayAttr(tileSizes); - op->walk([&](mlir::concretelang::FHELinalg::MatMulEintIntOp matmulOp) { - matmulOp.getOperation()->setAttr("tile-sizes", tileAttr); + op->walk([&](mlir::Operation *op) { + if (llvm::isa( + op) || + llvm::isa(op)) { + op->setAttr("tile-sizes", tileAttr); + } }); } diff --git 
a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp index bcea750ab6..d8f0ee1676 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp @@ -35,26 +35,54 @@ namespace concretelang { namespace { -class BufferizeRTTypesConverter - : public mlir::bufferization::BufferizeTypeConverter { +class BufferizeRTTypesConverter : public mlir::TypeConverter { +protected: + bufferization::BufferizeTypeConverter btc; + public: BufferizeRTTypesConverter() { + addConversion([&](mlir::Type type) { return btc.convertType(type); }); + + addConversion([&](mlir::RankedTensorType type) { + return mlir::MemRefType::get(type.getShape(), + this->convertType(type.getElementType())); + }); + + addConversion([&](mlir::UnrankedTensorType type) { + return mlir::UnrankedMemRefType::get( + this->convertType(type.getElementType()), 0); + }); + + addConversion([&](mlir::MemRefType type) { + return mlir::MemRefType::get(type.getShape(), + this->convertType(type.getElementType()), + type.getLayout(), type.getMemorySpace()); + }); + + addConversion([&](mlir::UnrankedMemRefType type) { + return mlir::UnrankedMemRefType::get( + this->convertType(type.getElementType()), type.getMemorySpace()); + }); + addConversion([&](mlir::concretelang::RT::FutureType type) { return mlir::concretelang::RT::FutureType::get( - this->convertType(type.dyn_cast() - .getElementType())); + this->convertType(type.getElementType())); }); + addConversion([&](mlir::concretelang::RT::PointerType type) { return mlir::concretelang::RT::PointerType::get( - this->convertType(type.dyn_cast() - .getElementType())); + this->convertType(type.getElementType())); }); + addConversion([&](mlir::FunctionType type) { SignatureConversion result(type.getNumInputs()); mlir::SmallVector newResults; + if (failed(this->convertSignatureArgs(type.getInputs(), result)) || - failed(this->convertTypes(type.getResults(), newResults))) + failed(this->convertTypes(type.getResults(), newResults))) { return type; + } + return mlir::FunctionType::get(type.getContext(), result.getConvertedTypes(), newResults); }); @@ -95,6 +123,26 @@ struct BufferizeDataflowTaskOpsPass mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, typeConverter); + patterns.add, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::memref::LoadOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::memref::StoreOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::memref::CopyOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::memref::SubViewOp, true>>(&getContext(), + typeConverter); + + target.addDynamicallyLegalOp( + [&](mlir::Operation *op) { + return typeConverter.isLegal(op->getResultTypes()) && + typeConverter.isLegal(op->getOperandTypes()); + }); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index b76b5c1d29..574cd38e4e 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ 
b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -568,11 +568,22 @@ struct FinalizeTaskCreationPass .setLayout(AffineMapAttr::get( builder.getMultiDimIdentityMap(rank))); } + + llvm::SmallVector dynamicDimSizes; + + for (auto dimSizeIt : llvm::enumerate(mrType.getShape())) { + if (mlir::ShapedType::isDynamic(dimSizeIt.value())) { + mlir::memref::DimOp dimOp = builder.create( + val.getLoc(), val, dimSizeIt.index()); + dynamicDimSizes.push_back(dimOp.getResult()); + } + } + // We need to make a copy of this MemRef to allow deallocation // based on refcounting - Value newval = - builder.create(val.getLoc(), mrType) - .getResult(); + mlir::memref::AllocOp newval = builder.create( + val.getLoc(), mrType, dynamicDimSizes); + builder.create(val.getLoc(), val, newval); clone = builder.create(op.getLoc(), builder.getI64IntegerAttr(1)); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt index 993e05688d..5168d47ea3 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt @@ -1,9 +1,11 @@ add_mlir_dialect_library( RTDialectTransforms BufferizableOpInterfaceImpl.cpp + HoistAwaitFuturePass.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/RT DEPENDS + RTTransformsIncGen mlir-headers LINK_LIBS PUBLIC diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp new file mode 100644 index 0000000000..f5e9784054 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp @@ -0,0 +1,259 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. 
+ +#include +#include +#include +#include + +#include + +#include +#include + +namespace { +struct HoistAwaitFuturePass + : public HoistAwaitFuturePassBase { + // Checks if all values of `a` are sizes of non-dynamic dimensions + bool allStatic(llvm::ArrayRef a) { + return llvm::all_of( + a, [](int64_t r) { return !mlir::ShapedType::isDynamic(r); }); + } + + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + + llvm::SmallVector opsToErase; + + func.walk([&](mlir::concretelang::RT::AwaitFutureOp awaitFutureOp) { + // Make sure there are no other consumers that rely on the + // synchronization + if (!awaitFutureOp.getResult().hasOneUse()) + return; + + mlir::scf::ForallOp forallOp = + llvm::dyn_cast(awaitFutureOp->getParentOp()); + + if (!forallOp) + return; + + mlir::tensor::ParallelInsertSliceOp parallelInsertSliceOp = + llvm::dyn_cast( + awaitFutureOp.getResult().getUses().begin()->getOwner()); + + if (!parallelInsertSliceOp) + return; + + // Make sure that the original tensor into which the + // synchronized values are inserted is a region out argument of + // the forall op and thus being written to concurrently + mlir::Value dst = parallelInsertSliceOp.getDest(); + + if (!llvm::any_of(forallOp.getRegionOutArgs(), + [=](mlir::Value output) { return output == dst; })) + return; + + // Currently, the tensor storing the futures must have a static + // shape, so only loops with static trip counts are supported + if (!(allStatic(forallOp.getStaticLowerBound()) && + allStatic(forallOp.getStaticUpperBound()) && + allStatic(forallOp.getStaticStep()))) + return; + + llvm::SmallVector tripCounts; + + for (auto [lb, ub, step] : llvm::zip_equal(forallOp.getStaticLowerBound(), + forallOp.getStaticUpperBound(), + forallOp.getStaticStep())) { + tripCounts.push_back( + mlir::concretelang::getStaticTripCount(lb, ub, step)); + } + + mlir::IRRewriter rewriter(&getContext()); + rewriter.setInsertionPoint(forallOp); + + mlir::Value tensorOfFutures = rewriter.create( + forallOp.getLoc(), tripCounts, awaitFutureOp.getInput().getType()); + + // Assemble the list of shared outputs that are to be preserved + // after the output storing the results of the `RT.await_future` + // has been removed + llvm::SmallVector newOutputs; + mlir::Value tensorOfValues; + size_t i = 0; + size_t oldResultIdx; + for (auto [output, regionOutArg] : llvm::zip_equal( + forallOp.getOutputs(), forallOp.getRegionOutArgs())) { + if (regionOutArg != dst) { + newOutputs.push_back(output); + } else { + tensorOfValues = output; + oldResultIdx = i; + } + + i++; + } + + newOutputs.push_back(tensorOfFutures); + + // Create a new forall loop with the same shared outputs, except + // that the one previously storing the contents of the + // `RT.await_future` ops is replaced with a tensor of futures + rewriter.setInsertionPointAfter(forallOp); + mlir::scf::ForallOp newForallOp = rewriter.create( + forallOp.getLoc(), forallOp.getMixedLowerBound(), + forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOutputs, + std::nullopt); + + // Move all operations from the old forall op to the new one + auto &newOperations = newForallOp.getBody()->getOperations(); + mlir::Block *oldBody = forallOp.getBody(); + + newOperations.splice(newOperations.begin(), oldBody->getOperations(), + oldBody->begin(), std::prev(oldBody->end())); + + // Wrap the future in a tensor of one element, so that it can be + // stored in the new shared output tensor of futures using + // `tensor.parallel_insert_slice` + 
rewriter.setInsertionPointAfter(awaitFutureOp); + mlir::Value futureAsTensor = + rewriter.create( + awaitFutureOp.getLoc(), + mlir::ValueRange{awaitFutureOp.getInput()}); + + // Move all operations from the old `scf.forall.in_parallel` + // terminator to the new one + mlir::scf::InParallelOp oldTerminator = forallOp.getTerminator(); + mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator(); + + mlir::Block::OpListType &oldTerminatorOps = + oldTerminator.getRegion().getBlocks().begin()->getOperations(); + mlir::Block::OpListType &newTerminatorOps = + newTerminator.getRegion().getBlocks().begin()->getOperations(); + + newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps, + oldTerminatorOps.begin(), oldTerminatorOps.end()); + + // Remap IVs and out args + for (auto [oldIV, newIV] : llvm::zip(forallOp.getInductionVars(), + newForallOp.getInductionVars())) { + oldIV.replaceAllUsesWith(newIV); + } + + { + size_t offs = 0; + for (auto it : llvm::enumerate(forallOp.getRegionOutArgs())) { + mlir::Value oldRegionOutArg = it.value(); + + if (oldRegionOutArg != dst) { + oldRegionOutArg.replaceAllUsesWith( + newForallOp.getRegionOutArgs()[it.index() - offs]); + } else { + offs++; + } + } + } + + // Create a new `tensor.parallel_insert_slice` operation inserting + // the future into the tensor of futures + llvm::SmallVector ones(tripCounts.size(), + rewriter.getI64IntegerAttr(1)); + + mlir::Value tensorOfFuturesRegionOutArg = + newForallOp.getRegionOutArgs().back(); + + mlir::ImplicitLocOpBuilder ilob(parallelInsertSliceOp.getLoc(), rewriter); + + rewriter.setInsertionPointAfter(parallelInsertSliceOp); + rewriter.create( + parallelInsertSliceOp.getLoc(), futureAsTensor, + tensorOfFuturesRegionOutArg, + mlir::getAsOpFoldResult(mlir::concretelang::normalizeInductionVars( + ilob, newForallOp.getInductionVars(), + newForallOp.getMixedLowerBound(), newForallOp.getMixedStep())), + ones, ones); + + // Create a new forall loop that invokes `RT.await_future` on + // all futures stored in the tensor of futures and writes the + // contents into the original tensor with the results + rewriter.setInsertionPointAfter(newForallOp); + mlir::scf::ForallOp syncForallOp = rewriter.create( + forallOp.getLoc(), forallOp.getMixedLowerBound(), + forallOp.getMixedUpperBound(), forallOp.getMixedStep(), + mlir::ValueRange{tensorOfValues}, std::nullopt); + + mlir::Value resultTensorOfFutures = newForallOp.getResults().back(); + + rewriter.setInsertionPointToStart(syncForallOp.getBody()); + mlir::Value extractedFuture = rewriter.create( + awaitFutureOp.getLoc(), resultTensorOfFutures, + syncForallOp.getInductionVars()); + mlir::concretelang::RT::AwaitFutureOp newAwaitFutureOp = + rewriter.create( + awaitFutureOp.getLoc(), awaitFutureOp.getResult().getType(), + extractedFuture); + + mlir::IRMapping syncMapping; + + for (auto [oldIV, newIV] : + llvm::zip_equal(newForallOp.getInductionVars(), + syncForallOp.getInductionVars())) { + syncMapping.map(oldIV, newIV); + } + + syncMapping.map(dst, syncForallOp.getOutputBlockArguments().back()); + syncMapping.map(parallelInsertSliceOp.getSource(), + newAwaitFutureOp.getResult()); + + mlir::scf::InParallelOp syncTerminator = syncForallOp.getTerminator(); + rewriter.setInsertionPointToStart(syncTerminator.getBody()); + rewriter.clone(*parallelInsertSliceOp.getOperation(), syncMapping); + + // Replace uses of the results of the original forall loop with: + // either the corresponding result from the new forall loop if + // this is a result unrelated to the futures or
with the result + // of the forall loop synchronizing the futures + { + size_t offs = 0; + for (size_t i = 0; i < forallOp.getNumResults(); i++) { + if (i == oldResultIdx) { + forallOp.getResult(i).replaceAllUsesWith(syncForallOp.getResult(0)); + offs = 1; + } else { + forallOp.getResult(i).replaceAllUsesWith( + newForallOp.getResult(i - offs)); + } + } + } + + // Replace the use of the shared output with the results of the + // original forall loop with the tensor outside of the loop so + // that there are no more references to values that were local + // to the original forall loop, enabling safe erasing of the old + // operations within the original forall loop + dst.replaceAllUsesWith( + forallOp.getOutputs().drop_front(oldResultIdx).front()); + parallelInsertSliceOp->erase(); + awaitFutureOp.erase(); + + // Defer erasing the original parallel loop that contained the + // `RT.await_future` operation until later in order to not + // confuse the walk relying on the parent operation + opsToErase.push_back(forallOp); + }); + + for (mlir::Operation *op : opsToErase) + op->erase(); + } +}; +} // namespace + +namespace mlir { +namespace concretelang { +std::unique_ptr> createHoistAwaitFuturePass() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index 0c93dc0931..31a7cdd04d 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -459,6 +459,9 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, return StreamStringError("Dataflow parallelization failed"); } + if (target == Target::FHE_DF_PARALLELIZED) + return std::move(res); + if (mlir::concretelang::pipeline::lowerLinalgToLoops( mlirContext, module, enablePass, loopParallelize) .failed()) { diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index c1f6187b1a..210d2ea03d 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -44,6 +44,7 @@ #include "concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h" #include "concretelang/Dialect/FHELinalg/Transforms/Tiling.h" #include "concretelang/Dialect/RT/Analysis/Autopar.h" +#include "concretelang/Dialect/RT/Transforms/Passes.h" #include "concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h" #include "concretelang/Dialect/TFHE/Transforms/Transforms.h" #include "concretelang/Support/CompilerEngine.h" @@ -183,6 +184,8 @@ mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, pm, mlir::concretelang::createBuildDataflowTaskGraphPass(), enablePass); addPotentiallyNestedPass( pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass); + addPotentiallyNestedPass(pm, mlir::concretelang::createHoistAwaitFuturePass(), + enablePass); return pm.run(module.getOperation()); } diff --git a/compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp b/compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp index 01bec55e56..2d30399492 100644 --- a/compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp +++ b/compilers/concrete-compiler/compiler/lib/Transforms/Batching.cpp @@ -516,18 +516,10 @@ buildNormalizedIndexes(mlir::PatternRewriter &rewriter, 
rewriter.setInsertionPointToStart(innermost.getBody()); for (mlir::scf::ForOp forOp : nest) { + mlir::ImplicitLocOpBuilder ilob(forOp.getLoc(), rewriter); - mlir::Value idxShifted = - isConstantIndexValue(forOp.getLowerBound(), 0) - ? forOp.getInductionVar() - : rewriter.create(innermost.getLoc(), - forOp.getInductionVar(), - forOp.getLowerBound()); - mlir::Value idx = - isConstantIndexValue(forOp.getStep(), 1) - ? idxShifted - : rewriter.create( - innermost.getLoc(), idxShifted, forOp.getStep()); + mlir::Value idx = normalizeInductionVar( + ilob, forOp.getInductionVar(), forOp.getLowerBound(), forOp.getStep()); res.push_back(idx); } diff --git a/compilers/concrete-compiler/compiler/src/main.cpp b/compilers/concrete-compiler/compiler/src/main.cpp index b5cfefdf00..a0abdd06f5 100644 --- a/compilers/concrete-compiler/compiler/src/main.cpp +++ b/compilers/concrete-compiler/compiler/src/main.cpp @@ -51,6 +51,7 @@ namespace optimizer = mlir::concretelang::optimizer; enum Action { ROUND_TRIP, DUMP_FHE, + DUMP_FHE_DF_PARALLELIZED, DUMP_FHE_LINALG_GENERIC, DUMP_FHE_NO_LINALG, DUMP_TFHE, @@ -142,6 +143,9 @@ static llvm::cl::opt action( llvm::cl::values(clEnumValN(Action::DUMP_FHE_LINALG_GENERIC, "dump-fhe-linalg-generic", "Lower FHELinalg to Linalg and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_FHE_DF_PARALLELIZED, + "dump-fhe-df-parallelized", + "Dump result after data-flow parallelization")), llvm::cl::values(clEnumValN(Action::DUMP_FHE_NO_LINALG, "dump-fhe-no-linalg", "Lower FHELinalg to FHE and dump result")), @@ -580,6 +584,9 @@ mlir::LogicalResult processInputBuffer( case Action::DUMP_FHE_LINALG_GENERIC: target = mlir::concretelang::CompilerEngine::Target::FHE_LINALG_GENERIC; break; + case Action::DUMP_FHE_DF_PARALLELIZED: + target = mlir::concretelang::CompilerEngine::Target::FHE_DF_PARALLELIZED; + break; case Action::DUMP_FHE_NO_LINALG: target = mlir::concretelang::CompilerEngine::Target::FHE_NO_LINALG; break; diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir index 8ca00b7244..069b514457 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir @@ -45,3 +45,31 @@ func.func @tiled_one_big_tile(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>) %0 = "FHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [0,0,4] } : (tensor<8x4x!FHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!FHE.eint<6>> return %0 : tensor<8x2x!FHE.eint<6>> } + +// ----- + +// CHECK: %[[V1:.*]] = scf.forall (%[[Varg2:.*]], %[[Varg3:.*]], %[[Varg4:.*]]) in (1, 1, 2) shared_outs(%[[Varg5:.*]] = %[[V0:.*]]) -> (tensor<2x3x4x!FHE.eint<2>>) { +// CHECK-NEXT: %[[V2:.*]] = affine.apply #map(%[[Varg2]]) +// CHECK-NEXT: %[[V3:.*]] = affine.apply #map1(%[[Varg3]]) +// CHECK-NEXT: %[[V4:.*]] = affine.apply #map(%[[Varg4]]) +// CHECK-NEXT: %[[V5:.*]] = affine.apply #map(%[[Varg2]]) +// CHECK-NEXT: %[[V6:.*]] = affine.apply #map1(%[[Varg3]]) +// CHECK-NEXT: %[[V7:.*]] = affine.apply #map(%[[Varg4]]) +// CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V2]], %[[V3]], %[[V4]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x4x!FHE.eint<2>> to tensor<2x3x2x!FHE.eint<2>> +// CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg5]]{{\[}}%[[V5]], %[[V6]], %[[V7]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x4x!FHE.eint<2>> to 
tensor<2x3x2x!FHE.eint<2>> +// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map2{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Vextracted_slice]] : tensor<2x3x2x!FHE.eint<2>>) outs(%[[Vextracted_slice_0]] : tensor<2x3x2x!FHE.eint<2>>) attrs = {"tile-sizes" = {{\[2, 3, 2\]}}} { +// CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<2>, %[[Vout:.*]]: !FHE.eint<2>): +// CHECK-NEXT: %[[V12:.*]] = "FHE.apply_lookup_table"(%[[Vin]], %[[Varg1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> +// CHECK-NEXT: linalg.yield %[[V12]] : !FHE.eint<2> +// CHECK-NEXT: } -> tensor<2x3x2x!FHE.eint<2>> +// CHECK-NEXT: %[[V9:.*]] = affine.apply #map(%[[Varg2]]) +// CHECK-NEXT: %[[V10:.*]] = affine.apply #map1(%[[Varg3]]) +// CHECK-NEXT: %[[V11:.*]] = affine.apply #map(%[[Varg4]]) +// CHECK-NEXT: scf.forall.in_parallel { +// CHECK-NEXT: tensor.parallel_insert_slice %[[V8]] into %[[Varg5]]{{\[}}%[[V9]], %[[V10]], %[[V11]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x2x!FHE.eint<2>> into tensor<2x3x4x!FHE.eint<2>> +// CHECK-NEXT: } +// CHECK-NEXT: } +func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> { + %1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1) { "tile-sizes" = [2,3,2] } : (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!FHE.eint<2>>) + return %1: tensor<2x3x4x!FHE.eint<2>> +} diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future.mlir new file mode 100644 index 0000000000..e6748d3224 --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future.mlir @@ -0,0 +1,20 @@ +// RUN: concretecompiler --action=dump-fhe-df-parallelized %s --optimizer-strategy=dag-mono --parallelize | FileCheck %s +// RUN: concretecompiler --action=dump-llvm-ir %s --optimizer-strategy=dag-mono --parallelize +// RUN: concretecompiler --action=dump-llvm-ir %s --optimizer-strategy=dag-multi --parallelize + +// CHECK: scf.forall.in_parallel { +// CHECK-NEXT: tensor.parallel_insert_slice %from_elements into %arg3[%arg2] [1] [1] : tensor<1x!RT.future>>> into tensor<4x!RT.future>>> +// CHECK-NEXT: } +// +// CHECK: %[[res:.*]] = scf.forall (%[[arg:.*]]) in (4) shared_outs(%[[so:.*]] = %[[init:.*]]) -> (tensor<8x9x4x!FHE.eint<6>>) { +// CHECK-NEXT: %[[extracted:.*]] = tensor.extract %4[%[[arg]]] : tensor<4x!RT.future>>> +// CHECK-NEXT: %[[awaitres:.*]] = "RT.await_future"(%[[extracted]]) : (!RT.future>>) -> tensor<8x9x!FHE.eint<6>> +// CHECK-NEXT: scf.forall.in_parallel { +// CHECK-NEXT: tensor.parallel_insert_slice %[[awaitres]] into %[[so]][0, 0, %[[arg]]] [8, 9, 1] [1, 1, 1] : tensor<8x9x!FHE.eint<6>> into tensor<8x9x4x!FHE.eint<6>> +// CHECK-NEXT: } +// CHECK-NEXT: } + +func.func @main(%a: tensor<8x7x!FHE.eint<6>>, %b: tensor<7x9xi7>) -> tensor<8x9x!FHE.eint<6>>{ + %0 = "FHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [0, 0, 2] } : (tensor<8x7x!FHE.eint<6>>, tensor<7x9xi7>) -> tensor<8x9x!FHE.eint<6>> + return %0 : tensor<8x9x!FHE.eint<6>> +} diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml index 005efce9fd..8f448c6426 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml +++ 
b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml @@ -2152,6 +2152,24 @@ tests: - tensor: [16,21,44,57,12,23,30,39,58,55,16,21,44,57,12,23] shape: [8,2] +--- +description: tiled_matmul_eint_int_1_1_7_partial_tiles +program: | + func.func @main(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8x2x!FHE.eint<6>> { + %0 = "FHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [0,0,7] } : (tensor<8x4x!FHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!FHE.eint<6>> + return %0 : tensor<8x2x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2] + shape: [8,4] + - tensor: [1,2,3,4,3,1,0,2] + shape: [4,2] + width: 8 + outputs: + - tensor: [16,21,44,57,12,23,30,39,58,55,16,21,44,57,12,23] + shape: [8,2] + --- description: extract_slice_parametric_2x2 program: |