Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pass hoisting RT.await_future out of scf.forall loops #748

Merged
merged 17 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
6b73879
feat(compiler): Add support for tiling of FHELinalg.apply_lookup_table
andidr Feb 26, 2024
5c81882
feat(compiler): Add new action dump-fhe-df-parallelized
andidr Feb 29, 2024
6c1cef1
refactor(compiler): Factor normalization of loop IVs from Batching-sp…
andidr Mar 14, 2024
93bb849
refactor(compiler): Use reinstantiating conversion patterns for RT op…
andidr Apr 3, 2024
74efe8b
feat(compiler): Declare RT futures usable as element types for memrefs
andidr Apr 3, 2024
a855e2b
refactor(compiler): Make type conversion in scalar FHE to TFHE conver…
andidr Apr 3, 2024
9216c61
refactor(compiler): Make type conversion in TFHE global parametrizati…
andidr Apr 3, 2024
68d0014
feat(compiler): Support non-ciphertext types in TFHE to Concrete conv…
andidr Apr 3, 2024
f668e82
refactor(compiler): Make type conversion in RT task bufferization rec…
andidr Apr 3, 2024
0c7e3a3
feat(compiler): Add support for nested Memrefs in memory usage estimator
andidr Apr 3, 2024
48d919b
feat(compiler): Add support for tensor.{from_elements,dim} operations…
andidr Apr 3, 2024
3ad3dcb
refactor(compiler): Use signature conversion for conversion of ops wi…
andidr Apr 3, 2024
999c9a9
feat(compiler): Add support for dynamically-sized memrefs in lowering…
andidr Apr 3, 2024
fd513f1
feat(compiler): Add support for various memref operations for RT task…
andidr Apr 3, 2024
d620fa9
feat(compiler): Add pass hoisting RT.await_future out of scf.forall l…
andidr Mar 14, 2024
7cf5483
test(compiler): Add tests for tiling generating partial tiles
andidr Apr 4, 2024
f506f5f
test(compiler): Add check tests for pass hoisting RT.await_future ope…
andidr Apr 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define CONCRETELANG_ANALYSIS_STATIC_LOOPS_H

#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/ImplicitLocOpBuilder.h>

namespace mlir {
namespace concretelang {
Expand Down Expand Up @@ -60,6 +61,16 @@ bool isConstantIndexValue(mlir::Value v);
int64_t getConstantIndexValue(mlir::Value v);
bool isConstantIndexValue(mlir::Value v, int64_t i);

mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder,
mlir::Value iv, mlir::OpFoldResult lb,
mlir::OpFoldResult step);

llvm::SmallVector<mlir::Value>
normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder,
mlir::ValueRange ivs,
llvm::ArrayRef<mlir::OpFoldResult> lbs,
llvm::ArrayRef<mlir::OpFoldResult> steps);

} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#ifndef CONCRETELANG_CONVERSION_RTOPCONVERTER_H_
#define CONCRETELANG_CONVERSION_RTOPCONVERTER_H_

#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
#include "concretelang/Conversion/Utils/Legality.h"
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -20,25 +20,25 @@ populateWithRTTypeConverterPatterns(mlir::RewritePatternSet &patterns,
mlir::ConversionTarget &target,
mlir::TypeConverter &converter) {
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::DataflowTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::DataflowYieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::MakeReadyFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::AwaitFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::CreateAsyncTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::CreateAsyncTaskOp, true>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::WorkFunctionReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(
patterns.getContext(), converter);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter,
mlir::Location loc,
mlir::ArrayAttr arrAttr);

mlir::Operation *convertOpWithBlocks(mlir::Operation *op,
mlir::ValueRange newOperands,
mlir::TypeRange newResultTypes,
mlir::TypeConverter &typeConverter,
mlir::ConversionPatternRewriter &rewriter);

} // namespace concretelang
} // namespace mlir
#endif
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ include "mlir/IR/BuiltinTypes.td"
class RT_Type<string name, list<Trait> traits = []> :
TypeDef<RT_Dialect, name, traits> { }

def RT_Future : RT_Type<"Future"> {
def RT_Future : RT_Type<"Future", [MemRefElementTypeInterface]> {
let mnemonic = "future";

let summary = "Future with a parameterized element type";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name RT)
add_public_tablegen_target(RTTransformsIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H
#define CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"

#include "concretelang/Dialect/RT/IR/RTDialect.h"

#define GEN_PASS_CLASSES
#include "concretelang/Dialect/RT/Transforms/Passes.h.inc"

namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<func::FuncOp>> createHoistAwaitFuturePass();
} // namespace concretelang
} // namespace mlir

#endif // CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#ifndef MLIR_DIALECT_RT_TRANSFORMS_PASSES
#define MLIR_DIALECT_RT_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

def HoistAwaitFuturePass : Pass<"hoist-await-future", "mlir::func::FuncOp"> {
let summary = "Hoists `RT.await_future` operations whose results are yielded "
"by `scf.forall` operations out of the loops";
let description = [{
Hoists `RT.await_future` operations whose results are yielded by
scf.forall operations out of the loops in order to avoid
over-synchronization of data-flow tasks.

E.g., the following IR:

```
scf.forall (%arg) in (16)
shared_outs(%o1 = %sometensor, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] :
tensor<...> into tensor<...>
...
}
}
```

is transformed into:

```
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>>

scf.forall (%arg) in (16)
shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%wrappedfuture = tensor.from_elements %future :
tensor<1x!RT.future<tensor<...>>>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] :
tensor<1xRT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>>
...
}
}

scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) {
%future = tensor.extract %tensoroffutures[%arg] :
tensor<4x!RT.future<tensor<...>>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
scf.forall.in_parallel {
antoniupop marked this conversation as resolved.
Show resolved Hide resolved
tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] :
tensor<...> into tensor<...>
}
}
```
}];
let constructor = "mlir::concretelang::createHoistAwaitFuturePass()";
let dependentDialects = [
"mlir::func::FuncDialect", "mlir::scf::SCFDialect",
"mlir::tensor::TensorDialect", "mlir::concretelang::RT::RTDialect"
];
}

#endif // MLIR_DIALECT_RT_TRANSFORMS_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ class CompilerEngine {
/// from the Linalg dialect
FHE_LINALG_GENERIC,

/// Read sources and lower all the FHELinalg operations to FHE
/// operations, dump after data-flow parallelization
FHE_DF_PARALLELIZED,

/// Read sources and lower all the FHELinalg operations to FHE operations
/// and scf loops
FHE_NO_LINALG,
Expand Down
40 changes: 40 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Arith/Utils/Utils.h>
#include <mlir/Dialect/Utils/StaticValueUtils.h>

#include <concretelang/Analysis/StaticLoops.h>
#include <optional>
Expand Down Expand Up @@ -446,5 +448,43 @@ bool isConstantIndexValue(mlir::Value v, int64_t i) {
return isConstantIndexValue(v) && getConstantIndexValue(v) == i;
}

// Returns a `Value` corresponding to `iv`, normalized to the lower
// bound `lb` and step `step` of a loop (i.e., (iv - lb) / step).
mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder,
mlir::Value iv, mlir::OpFoldResult lb,
mlir::OpFoldResult step) {
std::optional<int64_t> lbInt = mlir::getConstantIntValue(lb);
std::optional<int64_t> stepInt = mlir::getConstantIntValue(step);

mlir::Value idxShifted = lbInt.has_value() && *lbInt == 0
? iv
: builder.create<mlir::arith::SubIOp>(
iv, mlir::getValueOrCreateConstantIndexOp(
builder, builder.getLoc(), lb));

mlir::Value normalizedIV =
stepInt.has_value() && *stepInt == 1
? idxShifted
: builder.create<mlir::arith::DivSIOp>(
idxShifted, mlir::getValueOrCreateConstantIndexOp(
builder, builder.getLoc(), step));

return normalizedIV;
}

llvm::SmallVector<mlir::Value>
normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder,
mlir::ValueRange ivs,
llvm::ArrayRef<mlir::OpFoldResult> lbs,
llvm::ArrayRef<mlir::OpFoldResult> steps) {
llvm::SmallVector<mlir::Value> normalizedIVs;

for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
normalizedIVs.push_back(normalizeInductionVar(builder, iv, lb, step));
}

return normalizedIVs;
}

} // namespace concretelang
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,9 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
outs, maps, iteratorTypes, doc,
call, bodyBuilder);

if (lutOp->hasAttr("tile-sizes"))
genericOp->setAttr("tile-sizes", lutOp->getAttr("tile-sizes"));

rewriter.replaceOp(lutOp, {genericOp.getResult(0)});

return ::mlir::success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,45 +84,38 @@ TFHE::GLWECipherTextType convertEncrypted(mlir::MLIRContext *context,
return TFHE::GLWECipherTextType::get(context, TFHE::GLWESecretKey());
}

/// Converts `Tensor<FHE::AnyEncryptedInteger>` into a
/// `Tensor<TFHE::GlweCiphertext>` if the element type is appropriate.
/// Otherwise return the input type.
mlir::Type
maybeConvertEncryptedTensor(mlir::MLIRContext *context,
mlir::RankedTensorType maybeEncryptedTensor) {
if (!maybeEncryptedTensor.getElementType().isa<FHE::FheIntegerInterface>()) {
return (mlir::Type)(maybeEncryptedTensor);
}
auto currentShape = maybeEncryptedTensor.getShape();
return mlir::RankedTensorType::get(
currentShape,
TFHE::GLWECipherTextType::get(context, TFHE::GLWESecretKey()));
}

/// Converts any encrypted type to `TFHE::GlweCiphetext` if the
/// input type is appropriate. Otherwise return the input type.
mlir::Type maybeConvertEncrypted(mlir::MLIRContext *context, mlir::Type t) {
if (auto eint = t.dyn_cast<FHE::FheIntegerInterface>())
return convertEncrypted(context, eint);
return t;
}

/// The type converter used to convert `FHE` to `TFHE` types using the scalar
/// strategy.
class TypeConverter : public mlir::TypeConverter {

public:
TypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](mlir::FunctionType type) {
llvm::SmallVector<mlir::Type> inputTypes;
llvm::SmallVector<mlir::Type> resultTypes;

for (mlir::Type inputType : type.getInputs())
inputTypes.push_back(this->convertType(inputType));

for (mlir::Type resultType : type.getResults())
resultTypes.push_back(this->convertType(resultType));

mlir::Type res =
mlir::FunctionType::get(type.getContext(), inputTypes, resultTypes);

return res;
});
addConversion([](FHE::FheIntegerInterface type) {
return convertEncrypted(type.getContext(), type);
});
addConversion([](FHE::EncryptedBooleanType type) {
return TFHE::GLWECipherTextType::get(type.getContext(),
TFHE::GLWESecretKey());
});
addConversion([](mlir::RankedTensorType type) {
return maybeConvertEncryptedTensor(type.getContext(), type);
addConversion([&](mlir::RankedTensorType type) {
return mlir::RankedTensorType::get(
type.getShape(), this->convertType(type.getElementType()));
});
addConversion([&](mlir::concretelang::RT::FutureType type) {
return mlir::concretelang::RT::FutureType::get(
Expand Down Expand Up @@ -860,6 +853,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
mlir::scf::ForallOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::EmptyOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::FromElementsOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::DimOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::ParallelInsertSliceOp, true>>(&getContext(),
converter);
Expand All @@ -882,6 +879,8 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::EmptyOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::tensor::ParallelInsertSliceOp>(target, converter);

Expand Down Expand Up @@ -917,6 +916,9 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::EmptyOp>(
target, converter);

mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::DimOp>(
target, converter);

mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
converter);

Expand Down
Loading
Loading