diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp
index f5e9784054..0aab7065c0 100644
--- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp
@@ -8,12 +8,97 @@
 #include <concretelang/Dialect/RT/IR/RTOps.h>
 #include <concretelang/Dialect/RT/Transforms/Passes.h>
+#include <llvm/ADT/STLExtras.h>
 #include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/IR/IRMapping.h>
 #include <mlir/Transforms/RegionUtils.h>
 
 namespace {
 
+bool isClonableIVExpression(mlir::Value v, mlir::scf::ForallOp forallOp);
+
+// Checks if the operation `op` can be cloned safely for insertion
+// into a new loop. That is, it must either be defined above
+// `forallOp`, or it must be pure and all of its operands may only
+// reference loop IVs from `forallOp`, values defined above
+// `forallOp`, or intermediate values within the body of `forallOp`
+// with the same properties. Operations with regions are currently
+// not supported.
+bool isClonableIVOp(mlir::Operation *op, mlir::scf::ForallOp forallOp) {
+  return op->getParentRegion()->isAncestor(&forallOp.getRegion()) ||
+         (mlir::isPure(op) && op->getNumRegions() == 0 &&
+          llvm::all_of(op->getOperands(), [=](mlir::Value operand) {
+            return isClonableIVExpression(operand, forallOp);
+          }));
+}
+
+// Checks if a value `v` is a loop IV of `forallOp`, a value defined
+// above `forallOp`, or a value whose defining operation fulfills the
+// conditions of `isClonableIVOp`.
+bool isClonableIVExpression(mlir::Value v, mlir::scf::ForallOp forallOp) {
+  if (llvm::any_of(forallOp.getInductionVars(),
+                   [=](mlir::Value iv) { return v == iv; }))
+    return true;
+
+  if (mlir::areValuesDefinedAbove(mlir::ValueRange{v},
+                                  forallOp.getBodyRegion()))
+    return true;
+
+  if (v.getDefiningOp())
+    return isClonableIVOp(v.getDefiningOp(), forallOp);
+
+  return false;
+}
+
+mlir::Value cloneIVExpression(mlir::IRRewriter &rewriter, mlir::Value v,
+                              mlir::IRMapping &mapping,
+                              mlir::scf::ForallOp forallOp);
+
+// Clones the operation `op` for insertion into a new loop, first
+// cloning any of its operands that are not yet mapped and not
+// defined above `forallOp`.
+mlir::Operation *cloneIVOp(mlir::IRRewriter &rewriter, mlir::Operation *op,
+                           mlir::IRMapping &mapping,
+                           mlir::scf::ForallOp forallOp) {
+  assert(mlir::isPure(op));
+
+  for (mlir::Value operand : op->getOperands()) {
+    if (!mapping.contains(operand) &&
+        !mlir::areValuesDefinedAbove(mlir::ValueRange{operand},
+                                     forallOp.getBodyRegion())) {
+      cloneIVExpression(rewriter, operand, mapping, forallOp);
+    }
+  }
+
+  return rewriter.cloneWithoutRegions(*op, mapping);
+}
+
+// If `v` can be referenced safely from a new loop, `v` is returned
+// directly. If not, its defining ops are recursively cloned and the
+// cloned result corresponding to `v` is returned.
+mlir::Value cloneIVExpression(mlir::IRRewriter &rewriter, mlir::Value v,
+                              mlir::IRMapping &mapping,
+                              mlir::scf::ForallOp forallOp) {
+  if (mapping.contains(v))
+    return mapping.lookup(v);
+
+  if (mlir::areValuesDefinedAbove(mlir::ValueRange{v},
+                                  forallOp.getBodyRegion())) {
+    return v;
+  }
+
+  mlir::Operation *definingOp = v.getDefiningOp();
+
+  assert(definingOp);
+
+  mlir::Operation *clonedOp =
+      cloneIVOp(rewriter, definingOp, mapping, forallOp);
+
+  for (auto [res, cloneRes] :
+       llvm::zip_equal(definingOp->getResults(), clonedOp->getResults())) {
+    mapping.map(res, cloneRes);
+  }
+
+  return mapping.lookup(v);
+}
+
 struct HoistAwaitFuturePass
     : public HoistAwaitFuturePassBase<HoistAwaitFuturePass> {
   // Checks if all values of `a` are sizes of a non-dynamic dimensions
@@ -46,6 +131,24 @@ struct HoistAwaitFuturePass
     if (!parallelInsertSliceOp)
       return;
 
+    // Make sure that all offsets, sizes and strides used by the
+    // parallel insert slice op depend only on IVs of the forall, on
+    // intermediate values produced in the body or on values defined
+    // above.
+    auto isAttrOrClonableIVExpression = [=](mlir::OpFoldResult ofr) {
+      return ofr.is<mlir::Attribute>() ||
+             isClonableIVExpression(ofr.dyn_cast<mlir::Value>(), forallOp);
+    };
+
+    if (!llvm::all_of(parallelInsertSliceOp.getMixedOffsets(),
+                      isAttrOrClonableIVExpression) ||
+        !llvm::all_of(parallelInsertSliceOp.getMixedStrides(),
+                      isAttrOrClonableIVExpression) ||
+        !llvm::all_of(parallelInsertSliceOp.getMixedSizes(),
+                      isAttrOrClonableIVExpression)) {
+      return;
+    }
+
     // Make sure that the original tensor into which the
     // synchronized values are inserted is a region out argument of
     // the forall op and thus being written to concurrently
@@ -207,6 +310,18 @@ struct HoistAwaitFuturePass
     syncMapping.map(parallelInsertSliceOp.getSource(),
                     newAwaitFutureOp.getResult());
 
+    auto addMapping = [&](llvm::ArrayRef<mlir::OpFoldResult> ofrs) {
+      for (mlir::OpFoldResult ofr : ofrs) {
+        if (mlir::Value v = ofr.dyn_cast<mlir::Value>())
+          syncMapping.map(
+              v, cloneIVExpression(rewriter, v, syncMapping, forallOp));
+      }
+    };
+
+    addMapping(parallelInsertSliceOp.getMixedOffsets());
+    addMapping(parallelInsertSliceOp.getMixedStrides());
+    addMapping(parallelInsertSliceOp.getMixedSizes());
+
     mlir::scf::InParallelOp syncTerminator = syncForallOp.getTerminator();
     rewriter.setInsertionPointToStart(syncTerminator.getBody());
     rewriter.clone(*parallelInsertSliceOp.getOperation(), syncMapping);
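
For context, a condensed before/after sketch of the rewrite that the cloning helpers above enable, using the shapes from the test below. This is an illustration only: `%init`, `%fut` and `%futs` are placeholder names, and the async-task plumbing is elided.

    // Before: the await and the affine.apply feeding the insertion offset
    // both live in the compute forall.
    %res = scf.forall (%i) in (8) shared_outs(%out = %init) -> (tensor<16x!FHE.eint<6>>) {
      // ... async task launched here, yielding %fut ...
      %off = affine.apply affine_map<(d0) -> (d0 * 2)>(%i)
      %v = "RT.await_future"(%fut) : (!RT.future<tensor<2x!FHE.eint<6>>>) -> tensor<2x!FHE.eint<6>>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %v into %out[%off] [2] [1] : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>>
      }
    }

    // After: the compute forall only collects futures; a separate
    // synchronization forall awaits them, and cloneIVExpression recreates
    // the affine.apply on the new loop's IV so the insertion offset
    // remains valid there.
    %res = scf.forall (%i) in (8) shared_outs(%out = %init) -> (tensor<16x!FHE.eint<6>>) {
      %fut = tensor.extract %futs[%i] : tensor<8x!RT.future<tensor<2x!FHE.eint<6>>>>
      %v = "RT.await_future"(%fut) : (!RT.future<tensor<2x!FHE.eint<6>>>) -> tensor<2x!FHE.eint<6>>
      %off = affine.apply affine_map<(d0) -> (d0 * 2)>(%i)
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %v into %out[%off] [2] [1] : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>>
      }
    }
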
diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future_ivexpr.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future_ivexpr.mlir
new file mode 100644
index 0000000000..8fcd8866b3
--- /dev/null
+++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/RT/hoist_await_future_ivexpr.mlir
@@ -0,0 +1,48 @@
+// RUN: concretecompiler --action=dump-fhe-df-parallelized %s --optimizer-strategy=dag-mono --parallelize --passes hoist-await-future --skip-program-info | FileCheck %s
+
+func.func @_dfr_DFT_work_function__main0(%arg0: !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, %arg1: !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, %arg2: !RT.rtptr<!RT.future<tensor<2xi7>>>, %arg3: !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>) attributes {_dfr_work_function_attribute} {
+  return
+}
+
+// CHECK: %[[V3:.*]] = scf.forall (%[[Varg2:.*]]) in (8) shared_outs(%[[Varg3:.*]] = %[[V0:.*]]) -> (tensor<16x!FHE.eint<6>>) {
+// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[V2:.*]]{{\[}}%[[Varg2]]{{\]}} : tensor<8x!RT.future<tensor<2x!FHE.eint<6>>>>
+// CHECK-NEXT: %[[V4:.*]] = "RT.await_future"(%[[Vextracted]]) : (!RT.future<tensor<2x!FHE.eint<6>>>) -> tensor<2x!FHE.eint<6>>
+// CHECK-NEXT: %[[V5:.*]] = affine.apply #map(%[[Varg2]])
+// CHECK-NEXT: scf.forall.in_parallel {
+// CHECK-NEXT: tensor.parallel_insert_slice %[[V4]] into %[[Varg3]]{{\[}}%[[V5]]{{\] \[2\] \[1\]}} : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V3]] : tensor<16x!FHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+func.func @main(%arg0: tensor<16x!FHE.eint<6>>, %arg1: tensor<16xi7>) -> tensor<16x!FHE.eint<6>> {
+  %f = func.constant @_dfr_DFT_work_function__main0 : (!RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.rtptr<!RT.future<tensor<2xi7>>>, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>) -> ()
+  "RT.register_task_work_function"(%f) : ((!RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.rtptr<!RT.future<tensor<2xi7>>>, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>) -> ()) -> ()
+  %0 = "FHE.zero_tensor"() : () -> tensor<16x!FHE.eint<6>>
+  %1 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %0) -> (tensor<16x!FHE.eint<6>>) {
+    %2 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2)
+    %3 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2)
+    %4 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2)
+    %extracted_slice = tensor.extract_slice %arg0[%2] [2] [1] : tensor<16x!FHE.eint<6>> to tensor<2x!FHE.eint<6>>
+    %extracted_slice_0 = tensor.extract_slice %arg1[%3] [2] [1] : tensor<16xi7> to tensor<2xi7>
+    %extracted_slice_1 = tensor.extract_slice %arg3[%4] [2] [1] : tensor<16x!FHE.eint<6>> to tensor<2x!FHE.eint<6>>
+    %c0_i64 = arith.constant 0 : i64
+    %5 = "RT.make_ready_future"(%extracted_slice, %c0_i64) : (tensor<2x!FHE.eint<6>>, i64) -> !RT.future<tensor<2x!FHE.eint<6>>>
+    %c0_i64_2 = arith.constant 0 : i64
+    %6 = "RT.make_ready_future"(%extracted_slice_0, %c0_i64_2) : (tensor<2xi7>, i64) -> !RT.future<tensor<2xi7>>
+    %c0_i64_3 = arith.constant 0 : i64
+    %7 = "RT.make_ready_future"(%extracted_slice_1, %c0_i64_3) : (tensor<2x!FHE.eint<6>>, i64) -> !RT.future<tensor<2x!FHE.eint<6>>>
+    %f_4 = func.constant @_dfr_DFT_work_function__main0 : (!RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.rtptr<!RT.future<tensor<2xi7>>>, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>) -> ()
+    %c3_i64 = arith.constant 3 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %8 = "RT.build_return_ptr_placeholder"() : () -> !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>
+    "RT.create_async_task"(%f_4, %c3_i64, %c1_i64, %8, %5, %6, %7) {workfn = @_dfr_DFT_work_function__main0} : ((!RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.rtptr<!RT.future<tensor<2xi7>>>, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>) -> (), i64, i64, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.future<tensor<2x!FHE.eint<6>>>, !RT.future<tensor<2xi7>>, !RT.future<tensor<2x!FHE.eint<6>>>) -> ()
+    %9 = "RT.deref_return_ptr_placeholder"(%8) : (!RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>) -> !RT.future<tensor<2x!FHE.eint<6>>>
+    %10 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2)
+    %11 = "RT.await_future"(%9) : (!RT.future<tensor<2x!FHE.eint<6>>>) -> tensor<2x!FHE.eint<6>>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %11 into %arg3[%10] [2] [1] : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>>
+    }
+  }
+  return %1 : tensor<16x!FHE.eint<6>>
+}
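
Note that the new `isAttrOrClonableIVExpression` check is purely a bail-out: when any offset, size or stride of the `tensor.parallel_insert_slice` is not a clonable IV expression, the pass returns early and leaves the forall untouched, since the insert slice could not be rebuilt in the synchronization loop. A minimal sketch of such a rejected case, assuming a hypothetical non-pure `"test.effectful_index"` op (`%init` and `%fut` are again placeholders):

    %res = scf.forall (%i) in (8) shared_outs(%out = %init) -> (tensor<16x!FHE.eint<6>>) {
      // The offset is produced by a non-pure op inside the loop body, so it
      // is not a clonable IV expression and hoisting is skipped.
      %off = "test.effectful_index"(%i) : (index) -> index
      %v = "RT.await_future"(%fut) : (!RT.future<tensor<2x!FHE.eint<6>>>) -> tensor<2x!FHE.eint<6>>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %v into %out[%off] [2] [1] : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>>
      }
    }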