-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(compiler): Add pass hoisting RT.await_future out of scf.forall l…
…oops The new pass hoists `RT.await_future` operations whose results are yielded by scf.forall operations out of the loops in order to avoid over-synchronization of data-flow tasks. E.g., the following IR: ``` scf.forall (%arg) in (16) shared_outs(%o1 = %sometensor, %o2 = %someothertensor) -> (tensor<...>, tensor<...>) { ... %rph = "RT.build_return_ptr_placeholder"() : () -> !RT.rtptr<!RT.future<tensor<...>>> "RT.create_async_task"(..., %rph, ...) { ... } : ... %future = "RT.deref_return_ptr_placeholder"(%rph) : (!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>> %res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...> ... scf.forall.in_parallel { ... tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] : tensor<...> into tensor<...> ... } } ``` is transformed into: ``` %tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>> scf.forall (%arg) in (16) shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor) -> (tensor<...>, tensor<...>) { ... %rph = "RT.build_return_ptr_placeholder"() : () -> !RT.rtptr<!RT.future<tensor<...>>> "RT.create_async_task"(..., %rph, ...) { ... } : ... %future = "RT.deref_return_ptr_placeholder"(%rph) : (!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>> %wrappedfuture = tensor.from_elements %future : tensor<1x!RT.future<tensor<...>>> ... scf.forall.in_parallel { ... tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] : tensor<1xRT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>> ... } } scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) { %future = tensor.extract %tensoroffutures[%arg] : tensor<4x!RT.future<tensor<...>>> %res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...> scf.forall.in_parallel { tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] : tensor<...> into tensor<...> } } ```
- Loading branch information
Showing
10 changed files
with
400 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
add_subdirectory(Analysis) | ||
add_subdirectory(IR) | ||
add_subdirectory(Transforms) |
3 changes: 3 additions & 0 deletions
3
...lers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name RT) | ||
add_public_tablegen_target(RTTransformsIncGen) |
25 changes: 25 additions & 0 deletions
25
compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama | ||
// Exceptions. See | ||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt | ||
// for license information. | ||
|
||
#ifndef CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H | ||
#define CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
#include "concretelang/Dialect/RT/IR/RTDialect.h" | ||
|
||
#define GEN_PASS_CLASSES | ||
#include "concretelang/Dialect/RT/Transforms/Passes.h.inc" | ||
|
||
namespace mlir { | ||
namespace concretelang { | ||
std::unique_ptr<OperationPass<func::FuncOp>> createHoistAwaitFuturePass(); | ||
} // namespace concretelang | ||
} // namespace mlir | ||
|
||
#endif // CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H |
82 changes: 82 additions & 0 deletions
82
compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#ifndef MLIR_DIALECT_RT_TRANSFORMS_PASSES | ||
#define MLIR_DIALECT_RT_TRANSFORMS_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def HoistAwaitFuturePass : Pass<"hoist-await-future", "mlir::func::FuncOp"> { | ||
let summary = "Hoists `RT.await_future` operations whose results are yielded " | ||
"by `scf.forall` operations out of the loops"; | ||
let description = [{ | ||
Hoists `RT.await_future` operations whose results are yielded by | ||
scf.forall operations out of the loops in order to avoid | ||
over-synchronization of data-flow tasks. | ||
|
||
E.g., the following IR: | ||
|
||
``` | ||
scf.forall (%arg) in (16) | ||
shared_outs(%o1 = %sometensor, %o2 = %someothertensor) | ||
-> (tensor<...>, tensor<...>) | ||
{ | ||
... | ||
%rph = "RT.build_return_ptr_placeholder"() : | ||
() -> !RT.rtptr<!RT.future<tensor<...>>> | ||
"RT.create_async_task"(..., %rph, ...) { ... } : ... | ||
%future = "RT.deref_return_ptr_placeholder"(%rph) : | ||
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>> | ||
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...> | ||
... | ||
scf.forall.in_parallel { | ||
... | ||
tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] : | ||
tensor<...> into tensor<...> | ||
... | ||
} | ||
} | ||
``` | ||
|
||
is transformed into: | ||
|
||
``` | ||
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>> | ||
|
||
scf.forall (%arg) in (16) | ||
shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor) | ||
-> (tensor<...>, tensor<...>) | ||
{ | ||
... | ||
%rph = "RT.build_return_ptr_placeholder"() : | ||
() -> !RT.rtptr<!RT.future<tensor<...>>> | ||
"RT.create_async_task"(..., %rph, ...) { ... } : ... | ||
%future = "RT.deref_return_ptr_placeholder"(%rph) : | ||
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>> | ||
%wrappedfuture = tensor.from_elements %future : | ||
tensor<1x!RT.future<tensor<...>>> | ||
... | ||
scf.forall.in_parallel { | ||
... | ||
tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] : | ||
tensor<1xRT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>> | ||
... | ||
} | ||
} | ||
|
||
scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) { | ||
%future = tensor.extract %tensoroffutures[%arg] : | ||
tensor<4x!RT.future<tensor<...>>> | ||
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...> | ||
scf.forall.in_parallel { | ||
tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] : | ||
tensor<...> into tensor<...> | ||
} | ||
} | ||
``` | ||
}]; | ||
let constructor = "mlir::concretelang::createHoistAwaitFuturePass()"; | ||
let dependentDialects = [ | ||
"mlir::func::FuncDialect", "mlir::scf::SCFDialect", | ||
"mlir::tensor::TensorDialect", "mlir::concretelang::RT::RTDialect" | ||
]; | ||
} | ||
|
||
#endif // MLIR_DIALECT_RT_TRANSFORMS_PASSES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 2 additions & 0 deletions
2
compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.