Skip to content

Commit

Permalink
Convert unary generic to inplace form (#955)
Browse files Browse the repository at this point in the history
Adds a pattern that rewrites a unary linalg.generic op to operate in place,
i.e., to replace the output with its input when possible.
Additionally, `linalg-convert-add-in-place` is refactored and
generalized into `convert-linalg-to-inplace` to enable the pass to hold
various patterns.

The new pattern makes it possible to favor in-place bufferization when possible.
  • Loading branch information
adam-smnk authored Aug 9, 2024
1 parent 6bdaf15 commit 619fde1
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 142 deletions.
9 changes: 5 additions & 4 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -421,13 +421,14 @@ def LinalgConvertCompareSelectToMaximumfPass: Pass<"linalg-convert-compare-selec
"arith::ArithDialect"];
}

def ConvertAddInplacePass: Pass<"linalg-convert-add-in-place",
def ConvertLinalgToInplace: Pass<"convert-linalg-to-inplace",
"func::FuncOp">{
let summary = "Convert linalg add to in-place operation";
let summary = "Convert linalg ops to inplace operation";
let description = [{
Convert linalg add to in-place update operation.
Convert linalg ops to inplace update operation.
}];
let dependentDialects = ["linalg::LinalgDialect"];
let dependentDialects = ["linalg::LinalgDialect",
"arith::ArithDialect"];
}

def TppRunnerWrapper : Pass<"tpp-runner-wrapper", "ModuleOp">{
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/DefaultTppPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ struct DefaultTppPasses
pm.addPass(createCleanup());
} else {
pm.addPass(createFoldIntoEltwise());
pm.addNestedPass<func::FuncOp>(createConvertAddInplacePass());
pm.addNestedPass<func::FuncOp>(createConvertLinalgToInplace());
// Convert linalg.batch_matmul to linalg.matmul.
pm.addPass(createRewriteBatchMatmulToMatmul());

Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ add_mlir_library(TPPTransforms
IntelAMXTileConfig.cpp
IntelAMXTileConfigHoisting.cpp
LinalgConvertCompareSelectToMaximumfPass.cpp
ConvertAddInplacePass.cpp
ConvertLinalgToInplace.cpp
FoldIntoEltwise.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
91 changes: 0 additions & 91 deletions lib/TPP/Transforms/ConvertAddInplacePass.cpp

This file was deleted.

147 changes: 147 additions & 0 deletions lib/TPP/Transforms/ConvertLinalgToInplace.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
//===-ConvertLinalgToInplace.cpp ---------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_CONVERTLINALGTOINPLACE
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

using namespace mlir;

namespace {

// Rewrites a two-input `linalg.generic` whose payload is a single `arith.addf`
// of its two inputs into an in-place update: one of the inputs is reused as
// the output (init) operand, i.e. `out = a + b` becomes `b = a + b`.
//
// The input that is reused as the output is the one whose indexing map matches
// the original output map (the second input is preferred). The remaining
// operand, which may be broadcasted, stays as the only input.
struct ConvertAddInplace : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // With buffer semantics, redirecting the write into one of the inputs
    // would leave the original output buffer unwritten.
    if (!op.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(op, "expects tensor semantics");

    // Exactly two inputs and one output are required. This also rejects the
    // already converted form (one input, one output), which keeps the greedy
    // driver from re-matching its own result, and guards the input accesses
    // below.
    if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1)
      return rewriter.notifyMatchFailure(op, "not a binary operation");

    // The payload must consist of the addition and the terminator only.
    Block *body = op.getBody();
    if (body->getOperations().size() != 2)
      return rewriter.notifyMatchFailure(op, "expects a single payload op");
    auto addf = dyn_cast<arith::AddFOp>(&body->front());
    if (!addf)
      return rewriter.notifyMatchFailure(op, "payload is not arith.addf");

    // The rebuilt payload adds the (input, output) block arguments of the new
    // op. That is only equivalent when the original addition consumes both
    // input block arguments and its result is the yielded value.
    Value firstArg = body->getArgument(0);
    Value secondArg = body->getArgument(1);
    Value lhs = addf.getLhs();
    Value rhs = addf.getRhs();
    bool addsBothInputs = (lhs == firstArg && rhs == secondArg) ||
                          (lhs == secondArg && rhs == firstArg);
    Operation *terminator = body->getTerminator();
    if (!addsBothInputs || terminator->getNumOperands() != 1 ||
        terminator->getOperand(0) != addf.getResult())
      return rewriter.notifyMatchFailure(
          op, "expects the sum of both inputs to be yielded");

    // TODO: This needs to be changed in the future to a detailed analysis that
    // checks if the second input is not used subsequently
    if (op.getInputs()[0] == op.getInputs()[1])
      return failure();

    // Indexing maps are ordered as [input 0, input 1, output].
    SmallVector<AffineMap> maps = op.getIndexingMapsArray();
    AffineMap outputMap = maps[2];
    // Every output element has to be written exactly once; otherwise reading
    // the previous output value in the new payload changes the result.
    if (!outputMap.isPermutation())
      return rewriter.notifyMatchFailure(
          op, "expects a permutation output indexing map");

    // Pick the input that is accessed exactly like the original output.
    unsigned inplaceIdx;
    if (maps[1] == outputMap)
      inplaceIdx = 1;
    else if (maps[0] == outputMap)
      inplaceIdx = 0;
    else
      return rewriter.notifyMatchFailure(
          op, "no input is indexed like the output");
    unsigned otherIdx = 1 - inplaceIdx;

    Value inputs = op.getInputs()[otherIdx];
    Value outputs = op.getInputs()[inplaceIdx];
    // The new op keeps the original result types, which must match the type
    // of its init operand.
    if (outputs.getType() != op.getDpsInitOperand(0)->get().getType())
      return rewriter.notifyMatchFailure(
          op, "in-place operand type does not match the output");

    SmallVector<AffineMap> indexingMaps{maps[otherIdx], maps[inplaceIdx]};
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, op.getResultTypes(), inputs, outputs, indexingMaps,
        op.getIteratorTypesArray(),
        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
          auto scalarOp = builder.create<arith::AddFOp>(loc, regionArgs);
          builder.create<linalg::YieldOp>(loc, scalarOp.getResult());
        });
    return success();
  }
};

// Rewrites a unary elementwise `linalg.generic` on tensors so that its single
// input is used directly as the output (init) operand. The rewritten op has no
// inputs and reads its operand through the output block argument.
struct EltwiseUnaryGenericToInplace
    : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(genericOp, "expects tensor semantics");

    const bool isUnary =
        genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1;
    if (!isUnary)
      return rewriter.notifyMatchFailure(genericOp, "not a unary operation");

    Value input = genericOp.getInputs()[0];
    Value init = genericOp.getOutputs()[0];
    if (input.getType() != init.getType())
      return rewriter.notifyMatchFailure(
          genericOp, "input type does not match the output");

    // An elementwise op updates every output element, so the initial output
    // contents are irrelevant as long as the payload never reads them
    // (write-only output). Only then may the output operand be swapped.
    if (!linalg::isElementwise(genericOp))
      return rewriter.notifyMatchFailure(genericOp,
                                         "not an elementwise operation");
    if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
      return rewriter.notifyMatchFailure(genericOp,
                                         "expects output to be unused");

    // Being elementwise does not imply the input is indexed like the output,
    // e.g., one of the input dimensions may be fixed. The input can stand in
    // for the output only when both indexing maps are equal.
    SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
    AffineMap inputMap = maps[0];
    AffineMap outputMap = maps[1];
    if (inputMap != outputMap)
      return rewriter.notifyMatchFailure(genericOp,
                                         "expects matching indexing maps");

    // Build the replacement op: no inputs, the former input acting as output.
    ValueRange inplaceOutputs = genericOp.getInputs();
    SmallVector<Type> resultTypes{input.getType()};
    SmallVector<AffineMap> indexingMaps{outputMap};

    auto newGeneric = rewriter.create<linalg::GenericOp>(
        genericOp.getLoc(), resultTypes, /*inputs=*/ValueRange{},
        inplaceOutputs, indexingMaps, genericOp.getIteratorTypesArray());

    // Move the original payload into the new op.
    Region &newRegion = newGeneric.getRegion();
    rewriter.inlineRegionBefore(genericOp->getRegion(0), newRegion,
                                newRegion.begin());

    // The moved block still carries (input, output) arguments. Forward all
    // uses of the input argument to the output argument and drop the former.
    Block *payload = newGeneric.getBody();
    rewriter.replaceAllUsesWith(payload->getArgument(0),
                                payload->getArgument(1));
    payload->eraseArgument(0);

    rewriter.replaceOp(genericOp, newGeneric->getResults());

    return success();
  }
};

// Pass that greedily applies the in-place conversion patterns to the operation
// it runs on.
struct ConvertLinalgToInplace
    : public tpp::impl::ConvertLinalgToInplaceBase<ConvertLinalgToInplace> {
  // Registers every in-place conversion pattern into `patterns`.
  void populateCombinePatterns(RewritePatternSet &patterns) {
    MLIRContext *ctx = patterns.getContext();
    patterns.add<ConvertAddInplace>(ctx);
    patterns.add<EltwiseUnaryGenericToInplace>(ctx);
  }

  void runOnOperation() override {
    RewritePatternSet inplacePatterns(&getContext());
    populateCombinePatterns(inplacePatterns);
    // The driver result is intentionally ignored, so the pass itself is never
    // marked as failed by the rewrite.
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(inplacePatterns));
  }
};

} // namespace
45 changes: 0 additions & 45 deletions test/Passes/convert-add-in-place.mlir

This file was deleted.

Loading

0 comments on commit 619fde1

Please sign in to comment.