From 3cd3dff92a37afb1ec6d5a6dbb5c3b80d6035572 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 29 Jul 2024 15:52:02 +0200 Subject: [PATCH] fix(compiler): Convert scf.for to scf.parallel only if parallel attribute is true The pattern converting `scf.for` operations to `scf.parallel` operations from `lib/Transforms/ForLoopToParallel.cpp` contains an assertion that ensures that the source operation does not have any iteration arguments in order to keep the conversion as simple as possible. However, if the attribute `parallel` of the source operation is `false`, the operation is replaced with an identical clone and the conversion could be treated as a no-op. This change modifies the pattern, such that it simply fails if `parallel` is `false`, making the check for the absence of iteration arguments unnecessary and avoiding unnecessary bailouts by the compiler. --- .../lib/Transforms/ForLoopToParallel.cpp | 45 +++++++------------ 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Transforms/ForLoopToParallel.cpp b/compilers/concrete-compiler/compiler/lib/Transforms/ForLoopToParallel.cpp index dcbdfa0282..104ee00929 100644 --- a/compilers/concrete-compiler/compiler/lib/Transforms/ForLoopToParallel.cpp +++ b/compilers/concrete-compiler/compiler/lib/Transforms/ForLoopToParallel.cpp @@ -23,39 +23,26 @@ class ForOpPattern : public mlir::OpRewritePattern { matchAndRewrite(mlir::scf::ForOp forOp, mlir::PatternRewriter &rewriter) const override { auto attr = forOp->getAttrOfType("parallel"); - if (attr == nullptr) { + + if (!attr || !attr.getValue()) { return mlir::failure(); } + assert(forOp.getRegionIterArgs().size() == 0 && "unexpecting iter args when loops are bufferized"); - if (attr.getValue()) { - rewriter.replaceOpWithNewOp( - forOp, mlir::ValueRange{forOp.getLowerBound()}, - mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(), - std::nullopt, - [&](mlir::OpBuilder &builder, mlir::Location location, - mlir::ValueRange indVar, mlir::ValueRange iterArgs) { - mlir::IRMapping map; - map.map(forOp.getInductionVar(), indVar.front()); - for (auto &op : forOp.getRegion().front()) { - auto newOp = builder.clone(op, map); - map.map(op.getResults(), newOp->getResults()); - } - }); - } else { - rewriter.replaceOpWithNewOp( - forOp, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), - std::nullopt, - [&](mlir::OpBuilder &builder, mlir::Location location, - mlir::Value indVar, mlir::ValueRange iterArgs) { - mlir::IRMapping map; - map.map(forOp.getInductionVar(), indVar); - for (auto &op : forOp.getRegion().front()) { - auto newOp = builder.clone(op, map); - map.map(op.getResults(), newOp->getResults()); - } - }); - } + + rewriter.replaceOpWithNewOp( + forOp, mlir::ValueRange{forOp.getLowerBound()}, + mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(), std::nullopt, + [&](mlir::OpBuilder &builder, mlir::Location location, + mlir::ValueRange indVar, mlir::ValueRange iterArgs) { + mlir::IRMapping map; + map.map(forOp.getInductionVar(), indVar.front()); + for (auto &op : forOp.getRegion().front()) { + auto newOp = builder.clone(op, map); + map.map(op.getResults(), newOp->getResults()); + } + }); return mlir::success(); }