fix(compiler): Support indirect references to IVs in indexes when hoisting RT.await_future ops

Until now, the pass hoisting `RT.await_future` operations only supported
`tensor.parallel_insert_slice` operations that use loop induction
variables directly as indexes. Any more complex indexing expression
produced a domination error, since a `tensor.parallel_insert_slice`
cloned by the pass into an additional parallel for loop was left with
references to values from the original loop.

This change properly clones the operations producing intermediate values
within the original parallel for loop and thus adds support for indexing
expressions that reference loop IVs only indirectly (see the sketch
below).
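
For illustration, a minimal hand-written sketch of the kind of IR that
used to trigger the domination error and is now handled. The RT op and
type spellings and all SSA names here are illustrative only, not taken
from actual IR; what matters is that the insertion offset is computed
from the IV rather than being the IV itself:

    scf.forall (%i) in (8) shared_outs(%out = %init) -> (tensor<16xf64>) {
      // %init and %future stand in for values produced earlier; their
      // producing ops are elided in this sketch.
      %slice = "RT.await_future"(%future)
          : (!RT.future<tensor<2xf64>>) -> tensor<2xf64>
      // The offset references the IV %i only indirectly; when the insert
      // slice is hoisted into a separate synchronization loop, the ops
      // computing %off must be cloned along with it.
      %c2 = arith.constant 2 : index
      %off = arith.muli %i, %c2 : index
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %slice into %out[%off] [2] [1]
            : tensor<2xf64> into tensor<16xf64>
      }
    }
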
andidr committed Apr 23, 2024
1 parent e238067 commit 0c36fa7
Showing 1 changed file with 115 additions and 0 deletions.
@@ -8,12 +8,97 @@
#include <concretelang/Dialect/RT/IR/RTOps.h>
#include <concretelang/Dialect/RT/Transforms/Passes.h>

#include <llvm/Support/Debug.h>
#include <mlir/Dialect/Utils/StaticValueUtils.h>
#include <mlir/Transforms/RegionUtils.h>

#include <iterator>
#include <optional>

namespace {
bool isClonableIVExpression(mlir::Value v, mlir::scf::ForallOp forallOp);

// Checks if the operation `op` can be cloned safely for insertion
// into a new loop. That is, it must either be defined above `forallOp`,
// or all of its operands must reference only loop IVs of `forallOp`,
// values defined above `forallOp`, or intermediate values within the
// body of `forallOp` with the same properties. Operations with regions
// are currently not supported.
bool isClonableIVOp(mlir::Operation *op, mlir::scf::ForallOp forallOp) {
  return op->getParentRegion()->isAncestor(&forallOp.getRegion()) ||
         (mlir::isPure(op) && op->getNumRegions() == 0 &&
          llvm::all_of(op->getOperands(), [=](mlir::Value operand) {
            return isClonableIVExpression(operand, forallOp);
          }));
}

// Checks if a value `v` is a loop IV of `forallOp`, a value defined
// above `forallOp`, or a value whose defining operation fulfills the
// conditions of `isClonableIVOp`.
bool isClonableIVExpression(mlir::Value v, mlir::scf::ForallOp forallOp) {
  if (llvm::any_of(forallOp.getInductionVars(),
                   [=](mlir::Value iv) { return v == iv; }))
    return true;

  if (mlir::areValuesDefinedAbove(mlir::ValueRange{v},
                                  forallOp.getBodyRegion()))
    return true;

  if (v.getDefiningOp())
    return isClonableIVOp(v.getDefiningOp(), forallOp);

  return false;
}

mlir::Value cloneIVExpression(mlir::IRRewriter &rewriter, mlir::Value v,
                              mlir::IRMapping &mapping,
                              mlir::scf::ForallOp forallOp);

// Clones an operation `op` for insertion into a new loop
mlir::Operation *cloneIVOp(mlir::IRRewriter &rewriter, mlir::Operation *op,
                           mlir::IRMapping &mapping,
                           mlir::scf::ForallOp forallOp) {
  assert(mlir::isPure(op));

  for (mlir::Value operand : op->getOperands()) {
    if (!mapping.contains(operand) &&
        !mlir::areValuesDefinedAbove(mlir::ValueRange{operand},
                                     forallOp.getBodyRegion())) {
      cloneIVExpression(rewriter, operand, mapping, forallOp);
    }
  }

  return rewriter.cloneWithoutRegions(*op, mapping);
}

// If `v` can be referenced safely from a new loop, `v` is returned
// directly. If not, its defining ops are recursively cloned.
mlir::Value cloneIVExpression(mlir::IRRewriter &rewriter, mlir::Value v,
                              mlir::IRMapping &mapping,
                              mlir::scf::ForallOp forallOp) {
  if (mapping.contains(v))
    return mapping.lookup(v);

  if (mlir::areValuesDefinedAbove(mlir::ValueRange{v},
                                  forallOp.getBodyRegion())) {
    return v;
  }

  mlir::Operation *definingOp = v.getDefiningOp();

  assert(definingOp);

  mlir::Operation *clonedOp =
      cloneIVOp(rewriter, definingOp, mapping, forallOp);

  for (auto [res, cloneRes] :
       llvm::zip_equal(definingOp->getResults(), clonedOp->getResults())) {
    mapping.map(res, cloneRes);
  }

  return mapping.lookup(v);
}

struct HoistAwaitFuturePass
    : public HoistAwaitFuturePassBase<HoistAwaitFuturePass> {
  // Checks if all values of `a` are sizes of non-dynamic dimensions
@@ -46,6 +131,24 @@ struct HoistAwaitFuturePass
    if (!parallelInsertSliceOp)
      return;

    // Make sure that all offsets, sizes and strides used by the
    // parallel insert slice op depend only on IVs of the forall, on
    // intermediate values produced in the body or on values defined
    // above.
    auto isAttrOrClonableIVExpression = [=](mlir::OpFoldResult ofr) {
      return ofr.is<mlir::Attribute>() ||
             isClonableIVExpression(ofr.dyn_cast<mlir::Value>(), forallOp);
    };

    if (!llvm::all_of(parallelInsertSliceOp.getMixedOffsets(),
                      isAttrOrClonableIVExpression) ||
        !llvm::all_of(parallelInsertSliceOp.getMixedStrides(),
                      isAttrOrClonableIVExpression) ||
        !llvm::all_of(parallelInsertSliceOp.getMixedSizes(),
                      isAttrOrClonableIVExpression)) {
      return;
    }

    // Make sure that the original tensor into which the
    // synchronized values are inserted is a region out argument of
    // the forall op and thus being written to concurrently
@@ -207,6 +310,18 @@ struct HoistAwaitFuturePass
    syncMapping.map(parallelInsertSliceOp.getSource(),
                    newAwaitFutureOp.getResult());

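    // For each offset, stride and size given as an SSA value rather than a
    // static attribute, clone the ops computing it into the synchronization
    // loop and record the mapping, so that the parallel_insert_slice cloned
    // below no longer references values from the original loop.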
    auto addMapping = [&](llvm::ArrayRef<mlir::OpFoldResult> ofrs) {
      for (mlir::OpFoldResult ofr : ofrs) {
        if (mlir::Value v = ofr.dyn_cast<mlir::Value>())
          syncMapping.map(
              v, cloneIVExpression(rewriter, v, syncMapping, forallOp));
      }
    };

    addMapping(parallelInsertSliceOp.getMixedOffsets());
    addMapping(parallelInsertSliceOp.getMixedStrides());
    addMapping(parallelInsertSliceOp.getMixedSizes());

    mlir::scf::InParallelOp syncTerminator = syncForallOp.getTerminator();
    rewriter.setInsertionPointToStart(syncTerminator.getBody());
    rewriter.clone(*parallelInsertSliceOp.getOperation(), syncMapping);
