Skip to content

Commit

Permalink
feat(compiler): Add pass hoisting RT.await_future out of scf.forall l…
Browse files Browse the repository at this point in the history
…oops

The new pass hoists `RT.await_future` operations whose results are
yielded by scf.forall operations out of the loops in order to avoid
over-synchronization of data-flow tasks.

E.g., the following IR:

```
scf.forall (%arg) in (16)
  shared_outs(%o1 = %sometensor, %o2 = %someothertensor)
  -> (tensor<...>, tensor<...>)
{
  ...
  %rph = "RT.build_return_ptr_placeholder"() :
    () -> !RT.rtptr<!RT.future<tensor<...>>>
  "RT.create_async_task"(..., %rph, ...) { ... } : ...
  %future = "RT.deref_return_ptr_placeholder"(%rph) :
    (!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
  %res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
  ...
  scf.forall.in_parallel {
    ...
    tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] :
      tensor<...> into tensor<...>
    ...
  }
}
```

is transformed into:

```
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>>

scf.forall (%arg) in (16)
  shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor)
  -> (tensor<...>, tensor<...>)
{
  ...
  %rph = "RT.build_return_ptr_placeholder"() :
    () -> !RT.rtptr<!RT.future<tensor<...>>>
  "RT.create_async_task"(..., %rph, ...) { ... } : ...
  %future = "RT.deref_return_ptr_placeholder"(%rph) :
    (!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
  %wrappedfuture = tensor.from_elements %future :
    tensor<1x!RT.future<tensor<...>>>
  ...
  scf.forall.in_parallel {
    ...
    tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] :
      tensor<1xRT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>>
    ...
  }
}

scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) {
  %future = tensor.extract %tensoroffutures[%arg] :
    tensor<4x!RT.future<tensor<...>>>
  %res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] :
      tensor<...> into tensor<...>
  }
}
```
  • Loading branch information
andidr committed Apr 3, 2024
1 parent 0120fda commit 7b07f30
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder,
mlir::Value iv, mlir::OpFoldResult lb,
mlir::OpFoldResult step);

llvm::SmallVector<mlir::Value>
normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder,
mlir::ValueRange ivs,
llvm::ArrayRef<mlir::OpFoldResult> lbs,
llvm::ArrayRef<mlir::OpFoldResult> steps);

} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name RT)
add_public_tablegen_target(RTTransformsIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H
#define CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"

#include "concretelang/Dialect/RT/IR/RTDialect.h"

#define GEN_PASS_CLASSES
#include "concretelang/Dialect/RT/Transforms/Passes.h.inc"

namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<func::FuncOp>> createHoistAwaitFuturePass();
} // namespace concretelang
} // namespace mlir

#endif // CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#ifndef MLIR_DIALECT_RT_TRANSFORMS_PASSES
#define MLIR_DIALECT_RT_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

def HoistAwaitFuturePass : Pass<"hoist-await-future", "mlir::func::FuncOp"> {
let summary = "Hoists `RT.await_future` operations whose results are yielded "
"by `scf.forall` operations out of the loops";
let description = [{
Hoists `RT.await_future` operations whose results are yielded by
scf.forall operations out of the loops in order to avoid
over-synchronization of data-flow tasks.

E.g., the following IR:

```
scf.forall (%arg) in (16)
shared_outs(%o1 = %sometensor, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] :
tensor<...> into tensor<...>
...
}
}
```

is transformed into:

```
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>>

scf.forall (%arg) in (16)
shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%wrappedfuture = tensor.from_elements %future :
tensor<1x!RT.future<tensor<...>>>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] :
tensor<1xRT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>>
...
}
}

scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) {
%future = tensor.extract %tensoroffutures[%arg] :
tensor<4x!RT.future<tensor<...>>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
scf.forall.in_parallel {
tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] :
tensor<...> into tensor<...>
}
}
```
}];
let constructor = "mlir::concretelang::createHoistAwaitFuturePass()";
let dependentDialects = [
"mlir::func::FuncDialect", "mlir::scf::SCFDialect",
"mlir::tensor::TensorDialect", "mlir::concretelang::RT::RTDialect"
];
}

#endif // MLIR_DIALECT_RT_TRANSFORMS_PASSES
14 changes: 14 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,5 +472,19 @@ mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder,
return normalizedIV;
}

llvm::SmallVector<mlir::Value>
normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder,
mlir::ValueRange ivs,
llvm::ArrayRef<mlir::OpFoldResult> lbs,
llvm::ArrayRef<mlir::OpFoldResult> steps) {
llvm::SmallVector<mlir::Value> normalizedIVs;

for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
normalizedIVs.push_back(normalizeInductionVar(builder, iv, lb, step));
}

return normalizedIVs;
}

} // namespace concretelang
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,11 @@ void TFHEKeyNormalizationPass::runOnOperation() {
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::DimOp>(
target, typeConverter);

patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::tensor::ParallelInsertSliceOp>>(&getContext(), typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::tensor::ParallelInsertSliceOp>(target, typeConverter);

patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
conversion::TypeConverter>>(
&getContext(), typeConverter);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
add_mlir_dialect_library(
RTDialectTransforms
BufferizableOpInterfaceImpl.cpp
HoistAwaitFuturePass.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/RT
DEPENDS
RTTransformsIncGen
mlir-headers
LINK_LIBS
PUBLIC
Expand Down
Loading

0 comments on commit 7b07f30

Please sign in to comment.