fix(compiler): Support indirect references to IVs in indexes when hoisting RT.await_future ops

Until now, the pass hoisting `RT.await_future` operations only
supported `tensor.parallel_insert_slice` operations that use loop
induction variables directly as indexes. Any more complex indexing
expression produced a domination error, since a
`tensor.parallel_insert_slice` cloned by the pass into an additional
parallel for loop was left with references to values from the
original loop.

This change properly clones operations producing intermediate values
within the original parallel for loop and thus adds support for
indexing expressions that reference loop IVs only indirectly.
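
For illustration, a minimal fragment of the kind of IR this now handles
(adapted from the test added below; %i, %off, %v, %out, %init and %fut
are placeholder names, and %init and %fut stand for values produced
earlier, e.g. the future returned by an async task): the offset of the
`tensor.parallel_insert_slice` is not the IV %i itself but an
`affine.apply` of it.

%r = scf.forall (%i) in (8) shared_outs(%out = %init) -> (tensor<16x!FHE.eint<6>>) {
  // Indirect reference to the IV: the offset is computed from %i.
  %off = affine.apply affine_map<(d0) -> (d0 * 2)>(%i)
  %v = "RT.await_future"(%fut) : (!RT.future<tensor<2x!FHE.eint<6>>>) -> tensor<2x!FHE.eint<6>>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %v into %out[%off] [2] [1]
        : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>>
  }
}

The pass now clones the `affine.apply` (and any similar pure ops feeding
the offsets, sizes and strides) into the new loop that receives the
hoisted `RT.await_future`, instead of leaving a dangling reference into
the original loop body.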
andidr committed Apr 23, 2024
1 parent 16f0041 commit d48e7cb
Showing 2 changed files with 163 additions and 0 deletions.
@@ -8,12 +8,97 @@
#include <concretelang/Dialect/RT/IR/RTOps.h>
#include <concretelang/Dialect/RT/Transforms/Passes.h>

#include <llvm/Support/Debug.h>
#include <mlir/Dialect/Utils/StaticValueUtils.h>
#include <mlir/Transforms/RegionUtils.h>

#include <iterator>
#include <optional>

namespace {
bool isClonableIVExpression(mlir::Value v, mlir::scf::ForallOp forallOp);

// Checks if the operation `op` can be cloned safely for insertion
// into a new loop. That is, it must be defined above `forallOp`, or
// all of its operands must reference only loop IVs of `forallOp`,
// values defined above `forallOp`, or intermediate values within the
// body of `forallOp` with the same properties. Operations with
// regions are currently not supported.
bool isClonableIVOp(mlir::Operation *op, mlir::scf::ForallOp forallOp) {
  return op->getParentRegion()->isAncestor(&forallOp.getRegion()) ||
         (mlir::isPure(op) && op->getNumRegions() == 0 &&
          llvm::all_of(op->getOperands(), [=](mlir::Value operand) {
            return isClonableIVExpression(operand, forallOp);
          }));
}

// Checks if a value `v` is a loop IV of `forallOp`, a value defined
// above `forallOp`, or a value whose defining operation fulfills the
// conditions of `isClonableIVOp`.
bool isClonableIVExpression(mlir::Value v, mlir::scf::ForallOp forallOp) {
  if (llvm::any_of(forallOp.getInductionVars(),
                   [=](mlir::Value iv) { return v == iv; }))
    return true;

  if (mlir::areValuesDefinedAbove(mlir::ValueRange{v},
                                  forallOp.getBodyRegion()))
    return true;

  if (v.getDefiningOp())
    return isClonableIVOp(v.getDefiningOp(), forallOp);

  return false;
}

mlir::Value cloneIVExpression(mlir::IRRewriter &rewriter, mlir::Value v,
                              mlir::IRMapping &mapping,
                              mlir::scf::ForallOp forallOp);

// Clones an operation `op` for insertion into a new loop
mlir::Operation *cloneIVOp(mlir::IRRewriter &rewriter, mlir::Operation *op,
                           mlir::IRMapping &mapping,
                           mlir::scf::ForallOp forallOp) {
  assert(mlir::isPure(op));

  for (mlir::Value operand : op->getOperands()) {
    if (!mapping.contains(operand) &&
        !mlir::areValuesDefinedAbove(mlir::ValueRange{operand},
                                     forallOp.getBodyRegion())) {
      cloneIVExpression(rewriter, operand, mapping, forallOp);
    }
  }

  return rewriter.cloneWithoutRegions(*op, mapping);
}

// If `v` can be referenced safely from a new loop, `v` is returned
// directly. If not, its defining ops are recursively cloned.
mlir::Value cloneIVExpression(mlir::IRRewriter &rewriter, mlir::Value v,
                              mlir::IRMapping &mapping,
                              mlir::scf::ForallOp forallOp) {
  if (mapping.contains(v))
    return mapping.lookup(v);

  if (mlir::areValuesDefinedAbove(mlir::ValueRange{v},
                                  forallOp.getBodyRegion())) {
    return v;
  }

  mlir::Operation *definingOp = v.getDefiningOp();

  assert(definingOp);

  mlir::Operation *clonedOp =
      cloneIVOp(rewriter, definingOp, mapping, forallOp);

  for (auto [res, cloneRes] :
       llvm::zip_equal(definingOp->getResults(), clonedOp->getResults())) {
    mapping.map(res, cloneRes);
  }

  return mapping.lookup(v);
}

struct HoistAwaitFuturePass
    : public HoistAwaitFuturePassBase<HoistAwaitFuturePass> {
  // Checks if all values of `a` are sizes of non-dynamic dimensions
@@ -46,6 +131,24 @@ struct HoistAwaitFuturePass
    if (!parallelInsertSliceOp)
      return;

    // Make sure that all offsets, sizes and strides used by the
    // parallel insert slice op depend only on IVs of the forall, on
    // intermediate values produced in the body or on values defined
    // above.
    auto isAttrOrClonableIVExpression = [=](mlir::OpFoldResult ofr) {
      return ofr.is<mlir::Attribute>() ||
             isClonableIVExpression(ofr.dyn_cast<mlir::Value>(), forallOp);
    };

    if (!llvm::all_of(parallelInsertSliceOp.getMixedOffsets(),
                      isAttrOrClonableIVExpression) ||
        !llvm::all_of(parallelInsertSliceOp.getMixedStrides(),
                      isAttrOrClonableIVExpression) ||
        !llvm::all_of(parallelInsertSliceOp.getMixedSizes(),
                      isAttrOrClonableIVExpression)) {
      return;
    }

    // Make sure that the original tensor into which the
    // synchronized values are inserted is a region out argument of
    // the forall op and thus being written to concurrently
@@ -207,6 +310,18 @@ struct HoistAwaitFuturePass
    syncMapping.map(parallelInsertSliceOp.getSource(),
                    newAwaitFutureOp.getResult());

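    // Clone the computations behind any dynamic offsets, strides and
    // sizes of the parallel insert slice and record the mappings, so
    // that the insert slice cloned below references only values
    // available in the new loop.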
    auto addMapping = [&](llvm::ArrayRef<mlir::OpFoldResult> ofrs) {
      for (mlir::OpFoldResult ofr : ofrs) {
        if (mlir::Value v = ofr.dyn_cast<mlir::Value>())
          syncMapping.map(
              v, cloneIVExpression(rewriter, v, syncMapping, forallOp));
      }
    };

    addMapping(parallelInsertSliceOp.getMixedOffsets());
    addMapping(parallelInsertSliceOp.getMixedStrides());
    addMapping(parallelInsertSliceOp.getMixedSizes());

    mlir::scf::InParallelOp syncTerminator = syncForallOp.getTerminator();
    rewriter.setInsertionPointToStart(syncTerminator.getBody());
    rewriter.clone(*parallelInsertSliceOp.getOperation(), syncMapping);
@@ -0,0 +1,48 @@
// RUN: concretecompiler --action=dump-fhe-df-parallelized %s --optimizer-strategy=dag-mono --parallelize --passes hoist-await-future --skip-program-info | FileCheck %s

func.func @_dfr_DFT_work_function__main0(%arg0: !RT.rtptr<tensor<2x!FHE.eint<6>>>, %arg1: !RT.rtptr<tensor<2x!FHE.eint<6>>>, %arg2: !RT.rtptr<tensor<2xi7>>, %arg3: !RT.rtptr<tensor<2x!FHE.eint<6>>>) attributes {_dfr_work_function_attribute} {
  return
}

// CHECK: %[[V3:.*]] = scf.forall (%[[Varg2:.*]]) in (8) shared_outs(%[[Varg3:.*]] = %[[V0:.*]]) -> (tensor<16x!FHE.eint<6>>) {
// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[V2:.*]]{{\[}}%[[Varg2]]{{\]}} : tensor<8x!RT.future<tensor<2x!FHE.eint<6>>>>
// CHECK-NEXT: %[[V4:.*]] = "RT.await_future"(%[[Vextracted]]) : (!RT.future<tensor<2x!FHE.eint<6>>>) -> tensor<2x!FHE.eint<6>>
// CHECK-NEXT: %[[V5:.*]] = affine.apply #map(%[[Varg2]])
// CHECK-NEXT: scf.forall.in_parallel {
// CHECK-NEXT: tensor.parallel_insert_slice %[[V4]] into %[[Varg3]]{{\[}}%[[V5]]{{\] \[2\] \[1\]}} : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V3]] : tensor<16x!FHE.eint<6>>
// CHECK-NEXT: }
// CHECK-NEXT: }
func.func @main(%arg0: tensor<16x!FHE.eint<6>>, %arg1: tensor<16xi7>) -> tensor<16x!FHE.eint<6>> {
  %f = constant @_dfr_DFT_work_function__main0 : (!RT.rtptr<tensor<2x!FHE.eint<6>>>, !RT.rtptr<tensor<2x!FHE.eint<6>>>, !RT.rtptr<tensor<2xi7>>, !RT.rtptr<tensor<2x!FHE.eint<6>>>) -> ()
  "RT.register_task_work_function"(%f) : ((!RT.rtptr<tensor<2x!FHE.eint<6>>>, !RT.rtptr<tensor<2x!FHE.eint<6>>>, !RT.rtptr<tensor<2xi7>>, !RT.rtptr<tensor<2x!FHE.eint<6>>>) -> ()) -> ()
  %0 = "FHE.zero_tensor"() : () -> tensor<16x!FHE.eint<6>>
  %1 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %0) -> (tensor<16x!FHE.eint<6>>) {
    %2 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2)
    %3 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2)
    %4 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2)
    %extracted_slice = tensor.extract_slice %arg0[%2] [2] [1] : tensor<16x!FHE.eint<6>> to tensor<2x!FHE.eint<6>>
    %extracted_slice_0 = tensor.extract_slice %arg1[%3] [2] [1] : tensor<16xi7> to tensor<2xi7>
    %extracted_slice_1 = tensor.extract_slice %arg3[%4] [2] [1] : tensor<16x!FHE.eint<6>> to tensor<2x!FHE.eint<6>>
    %c0_i64 = arith.constant 0 : i64
    %5 = "RT.make_ready_future"(%extracted_slice, %c0_i64) : (tensor<2x!FHE.eint<6>>, i64) -> !RT.future<tensor<2x!FHE.eint<6>>>
    %c0_i64_2 = arith.constant 0 : i64
    %6 = "RT.make_ready_future"(%extracted_slice_0, %c0_i64_2) : (tensor<2xi7>, i64) -> !RT.future<tensor<2xi7>>
    %c0_i64_3 = arith.constant 0 : i64
    %7 = "RT.make_ready_future"(%extracted_slice_1, %c0_i64_3) : (tensor<2x!FHE.eint<6>>, i64) -> !RT.future<tensor<2x!FHE.eint<6>>>
    %f_4 = func.constant @_dfr_DFT_work_function__main0 : (!RT.rtptr<tensor<2x!FHE.eint<6>>>, !RT.rtptr<tensor<2x!FHE.eint<6>>>, !RT.rtptr<tensor<2xi7>>, !RT.rtptr<tensor<2x!FHE.eint<6>>>) -> ()
    %c3_i64 = arith.constant 3 : i64
    %c1_i64 = arith.constant 1 : i64
    %8 = "RT.build_return_ptr_placeholder"() : () -> !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>
    "RT.create_async_task"(%f_4, %c3_i64, %c1_i64, %8, %5, %6, %7) {workfn = @_dfr_DFT_work_function__main0} : ((!RT.rtptr<tensor<2x!FHE.eint<6>>>, !RT.rtptr<tensor<2x!FHE.eint<6>>>, !RT.rtptr<tensor<2xi7>>, !RT.rtptr<tensor<2x!FHE.eint<6>>>) -> (), i64, i64, !RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>, !RT.future<tensor<2x!FHE.eint<6>>>, !RT.future<tensor<2xi7>>, !RT.future<tensor<2x!FHE.eint<6>>>) -> ()
    %9 = "RT.deref_return_ptr_placeholder"(%8) : (!RT.rtptr<!RT.future<tensor<2x!FHE.eint<6>>>>) -> !RT.future<tensor<2x!FHE.eint<6>>>
    %10 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg2)
    %11 = "RT.await_future"(%9) : (!RT.future<tensor<2x!FHE.eint<6>>>) -> tensor<2x!FHE.eint<6>>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %11 into %arg3[%10] [2] [1] : tensor<2x!FHE.eint<6>> into tensor<16x!FHE.eint<6>>
    }
  }
  return %1 : tensor<16x!FHE.eint<6>>
}
