From a53153a80427d83d7c524f54bfbb6d0b8544dd8d Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Sat, 30 Sep 2023 00:16:49 +0530 Subject: [PATCH] [LinalgExt][Bufferization] Clean-up bufferization pass for LinalgExt (#15040) -- This commit cleans up bufferization pass for LinalgExt ops by making use of `DstBufferizableOpInterfaceExternalModel`. Signed-off-by: Abhishek Varma --- .../Interfaces/BufferizationInterfaces.cpp | 54 +++---------------- 1 file changed, 6 insertions(+), 48 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp index bc7bbe5c2e54..89c675aae093 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp @@ -15,6 +15,7 @@ #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" @@ -344,59 +345,16 @@ static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter, /// a new op that operates entirely on memrefs. template struct LinalgExtOpInterface - : public BufferizableOpInterface::ExternalModel, - OpTy> { + : public bufferization::DstBufferizableOpInterfaceExternalModel< + LinalgExtOpInterface, OpTy> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - // TODO: Implement payloadUsesValueFromOperand for individual ops. There - // are a limited number of LinalgExt ops, so we hardcode them here. We don't - // expect to add more LinalgExt ops. - if (!cast(op).isDpsInit(&opOperand)) - return true; + // TODO: Revisit this for Scatter/ReverseOp. We can then get rid of + // `bufferizesToMemoryRead` completely. return !isa(op); } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - // Operand is written to if it has an aliasing OpResult. - auto bufferizableOp = cast(op); - return !bufferizableOp.getAliasingValues(opOperand, state) - .getAliases() - .empty(); - } - - bufferization::AliasingOpOperandList - getAliasingOpOperands(Operation *op, Value value, - const AnalysisState &state) const { - size_t resultNum = std::distance(op->getOpResults().begin(), - llvm::find(op->getOpResults(), value)); - // The i-th OpResult may alias with the i-th "out" tensor. - return {AliasingOpOperand( - &cast(op).getDpsInitsMutable()[resultNum], - BufferRelation::Equivalent, - /*isDefinite=*/false)}; - } - - bufferization::AliasingValueList - getAliasingValues(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - auto dspOp = cast(op); - - // The i-th "out" tensor may alias with the i-th OpResult. - if (dspOp.isDpsInit(&opOperand)) { - return {AliasingValue(dspOp.getTiedOpResult(&opOperand) /*result*/, - BufferRelation::Equivalent, - /*isDefinite=*/false)}; - } - return {}; - } - - bufferization::BufferRelation - bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return bufferization::BufferRelation::Equivalent; - } - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { return bufferizeLinalgExtOp(