From d48e7cb571691bca5816628658cd4c9bac3b419a Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 23 Apr 2024 06:17:27 +0200 Subject: [PATCH] fix(compiler): Support indirect references to IVs in indexes when hoisting RT.await_future ops Until now, the pass hoisting `RT.await_future` operations only supports `tensor.parallel_insert_slice` operations that use loop induction variables directly as indexes. Any more complex indexing expressions produce a domination error, since a `tensor.parallel_insert_slice` cloned by the pass into an additional parallel for loop is left with references to values from the original loop. This change properly clones operations producing intermediate values within the original parallel for loop and thus adds support for indexing expressions that reference loops IVs only indirectly. --- .../RT/Transforms/HoistAwaitFuturePass.cpp | 115 ++++++++++++++++++ .../Dialect/RT/hoist_await_future_ivexpr.mlir | 48 ++++++++ 2 files changed, 163 insertions(+) create mode 100644 compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future_ivexpr.mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp index f5e9784054..0aab7065c0 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp @@ -8,12 +8,97 @@ #include #include +#include #include +#include #include #include namespace { +bool isClonableIVExpression(mlir::Value v, mlir::scf::ForallOp forallOp); + +// Checks if the operation `op` can be cloned safely for insertion +// into a new loop. That is, it must be above `forallOp` or if all of +// its operands only reference loop IVs from `forallOp`, values +// defined above `forallOp` or intermediate values within the body of +// `forallOp` with the same properties. Operations with regions are +// currently not supported. +bool isClonableIVOp(mlir::Operation *op, mlir::scf::ForallOp forallOp) { + return op->getParentRegion()->isAncestor(&forallOp.getRegion()) || + (mlir::isPure(op) && op->getNumRegions() == 0 && + llvm::all_of(op->getOperands(), [=](mlir::Value operand) { + return isClonableIVExpression(operand, forallOp); + })); +} + +// Checks if a value `v` is a loop IV, a value defined above +// `forallOp` or if the defining operation fulfills the conditions of +// `isClonableIVOp`. +bool isClonableIVExpression(mlir::Value v, mlir::scf::ForallOp forallOp) { + if (llvm::any_of(forallOp.getInductionVars(), + [=](mlir::Value iv) { return v == iv; })) + return true; + + if (mlir::areValuesDefinedAbove(mlir::ValueRange{v}, + forallOp.getBodyRegion())) + return true; + + if (v.getDefiningOp()) + return isClonableIVOp(v.getDefiningOp(), forallOp); + + return false; +} + +mlir::Value cloneIVExpression(mlir::IRRewriter &rewriter, mlir::Value v, + mlir::IRMapping &mapping, + mlir::scf::ForallOp forallOp); + +// Clones an operation `op` for insertion into a new loop +mlir::Operation *cloneIVOp(mlir::IRRewriter &rewriter, mlir::Operation *op, + mlir::IRMapping &mapping, + mlir::scf::ForallOp forallOp) { + assert(mlir::isPure(op)); + + for (mlir::Value operand : op->getOperands()) { + if (!mapping.contains(operand) && + !mlir::areValuesDefinedAbove(mlir::ValueRange{operand}, + forallOp.getBodyRegion())) { + cloneIVExpression(rewriter, operand, mapping, forallOp); + } + } + + return rewriter.cloneWithoutRegions(*op, mapping); +} + +// If `v` can be referenced safely from a new loop, `v` is returned +// directly. If not, its defining ops are recursively cloned. +mlir::Value cloneIVExpression(mlir::IRRewriter &rewriter, mlir::Value v, + mlir::IRMapping &mapping, + mlir::scf::ForallOp forallOp) { + if (mapping.contains(v)) + return mapping.lookup(v); + + if (mlir::areValuesDefinedAbove(mlir::ValueRange{v}, + forallOp.getBodyRegion())) { + return v; + } + + mlir::Operation *definingOp = v.getDefiningOp(); + + assert(definingOp); + + mlir::Operation *clonedOp = + cloneIVOp(rewriter, definingOp, mapping, forallOp); + + for (auto [res, cloneRes] : + llvm::zip_equal(definingOp->getResults(), clonedOp->getResults())) { + mapping.map(res, cloneRes); + } + + return mapping.lookup(v); +} + struct HoistAwaitFuturePass : public HoistAwaitFuturePassBase { // Checks if all values of `a` are sizes of a non-dynamic dimensions @@ -46,6 +131,24 @@ struct HoistAwaitFuturePass if (!parallelInsertSliceOp) return; + // Make sure that all indexes, offsets and strides used by the + // parallel insert slice op depend only on IVs of the forall, on + // intermediate values produced in the body or on values defined + // above. + auto isAttrOrClonableIVExpression = [=](mlir::OpFoldResult ofr) { + return ofr.is() || + isClonableIVExpression(ofr.dyn_cast(), forallOp); + }; + + if (!llvm::all_of(parallelInsertSliceOp.getMixedOffsets(), + isAttrOrClonableIVExpression) || + !llvm::all_of(parallelInsertSliceOp.getMixedStrides(), + isAttrOrClonableIVExpression) || + !llvm::all_of(parallelInsertSliceOp.getMixedSizes(), + isAttrOrClonableIVExpression)) { + return; + } + // Make sure that the original tensor into which the // synchronized values are inserted is a region out argument of // the forall op and thus being written to concurrently @@ -207,6 +310,18 @@ struct HoistAwaitFuturePass syncMapping.map(parallelInsertSliceOp.getSource(), newAwaitFutureOp.getResult()); + auto addMapping = [&](llvm::ArrayRef ofrs) { + for (mlir::OpFoldResult ofr : ofrs) { + if (mlir::Value v = ofr.dyn_cast()) + syncMapping.map( + v, cloneIVExpression(rewriter, v, syncMapping, forallOp)); + } + }; + + addMapping(parallelInsertSliceOp.getMixedOffsets()); + addMapping(parallelInsertSliceOp.getMixedStrides()); + addMapping(parallelInsertSliceOp.getMixedSizes()); + mlir::scf::InParallelOp syncTerminator = syncForallOp.getTerminator(); rewriter.setInsertionPointToStart(syncTerminator.getBody()); rewriter.clone(*parallelInsertSliceOp.getOperation(), syncMapping); diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future_ivexpr.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future_ivexpr.mlir new file mode 100644 index 0000000000..8fcd8866b3 --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future_ivexpr.mlir @@ -0,0 +1,48 @@ +// RUN: concretecompiler --action=dump-fhe-df-parallelized %s --optimizer-strategy=dag-mono --parallelize --passes hoist-await-future --skip-program-info | FileCheck %s + +func.func @_dfr_DFT_work_function__main0(%arg0: !RT.rtptr>>, %arg1: !RT.rtptr>>, %arg2: !RT.rtptr>, %arg3: !RT.rtptr>>) attributes {_dfr_work_function_attribute} { + return +} + +// CHECK: %[[V3:.*]] = scf.forall (%[[Varg2:.*]]) in (8) shared_outs(%[[Varg3:.*]] = %[[V0:.*]]) -> (tensor<16x!FHE.eint<6>>) { +// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[V2:.*]]{{\[}}%[[Varg2]]{{\]}} : tensor<8x!RT.future>>> +// CHECK-NEXT: %[[V4:.*]] = "RT.await_future"(%[[Vextracted]]) : (!RT.future>>) -> tensor<2x!FHE.eint<6>> +// CHECK-NEXT: %[[V5:.*]] = affine.apply #map(%[[Varg2]]) +// CHECK-NEXT: scf.forall.in_parallel { +// CHECK-NEXT: tensor.parallel_insert_slice %[[V4]] into %[[Varg3]]{{\[}}%[[V5]]{{\] \[2\] \[1\]}} : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V3]] : tensor<16x!FHE.eint<6>> +// CHECK-NEXT: } +// CHECK-NEXT: } +func.func @main(%arg0: tensor<16x!FHE.eint<6>>, %arg1: tensor<16xi7>) -> tensor<16x!FHE.eint<6>> { + %f = constant @_dfr_DFT_work_function__main0 : (!RT.rtptr>>, !RT.rtptr>>, !RT.rtptr>, !RT.rtptr>>) -> () + "RT.register_task_work_function"(%f) : ((!RT.rtptr>>, !RT.rtptr>>, !RT.rtptr>, !RT.rtptr>>) -> ()) -> () + %0 = "FHE.zero_tensor"() : () -> tensor<16x!FHE.eint<6>> + %1 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %0) -> (tensor<16x!FHE.eint<6>>) { + %2 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2) + %3 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2) + %4 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2) + %extracted_slice = tensor.extract_slice %arg0[%2] [2] [1] : tensor<16x!FHE.eint<6>> to tensor<2x!FHE.eint<6>> + %extracted_slice_0 = tensor.extract_slice %arg1[%3] [2] [1] : tensor<16xi7> to tensor<2xi7> + %extracted_slice_1 = tensor.extract_slice %arg3[%4] [2] [1] : tensor<16x!FHE.eint<6>> to tensor<2x!FHE.eint<6>> + %c0_i64 = arith.constant 0 : i64 + %5 = "RT.make_ready_future"(%extracted_slice, %c0_i64) : (tensor<2x!FHE.eint<6>>, i64) -> !RT.future>> + %c0_i64_2 = arith.constant 0 : i64 + %6 = "RT.make_ready_future"(%extracted_slice_0, %c0_i64_2) : (tensor<2xi7>, i64) -> !RT.future> + %c0_i64_3 = arith.constant 0 : i64 + %7 = "RT.make_ready_future"(%extracted_slice_1, %c0_i64_3) : (tensor<2x!FHE.eint<6>>, i64) -> !RT.future>> + %f_4 = func.constant @_dfr_DFT_work_function__main0 : (!RT.rtptr>>, !RT.rtptr>>, !RT.rtptr>, !RT.rtptr>>) -> () + %c3_i64 = arith.constant 3 : i64 + %c1_i64 = arith.constant 1 : i64 + %8 = "RT.build_return_ptr_placeholder"() : () -> !RT.rtptr>>> + "RT.create_async_task"(%f_4, %c3_i64, %c1_i64, %8, %5, %6, %7) {workfn = @_dfr_DFT_work_function__main0} : ((!RT.rtptr>>, !RT.rtptr>>, !RT.rtptr>, !RT.rtptr>>) -> (), i64, i64, !RT.rtptr>>>, !RT.future>>, !RT.future>, !RT.future>>) -> () + %9 = "RT.deref_return_ptr_placeholder"(%8) : (!RT.rtptr>>>) -> !RT.future>> + %10 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2) + %11 = "RT.await_future"(%9) : (!RT.future>>) -> tensor<2x!FHE.eint<6>> + scf.forall.in_parallel { + tensor.parallel_insert_slice %11 into %arg3[%10] [2] [1] : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>> + } + } + return %1 : tensor<16x!FHE.eint<6>> +}