Skip to content

Commit

Permalink
[LinalgExt][Bufferization] Clean-up bufferization pass for LinalgExt (i…
Browse files Browse the repository at this point in the history
…ree-org#15040)

-- This commit cleans up bufferization pass for LinalgExt ops by making
   use of `DstBufferizableOpInterfaceExternalModel`.

Signed-off-by: Abhishek Varma <[email protected]>
  • Loading branch information
Abhishek-Varma authored Sep 29, 2023
1 parent 1ba5e37 commit a53153a
Showing 1 changed file with 6 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
Expand Down Expand Up @@ -344,59 +345,16 @@ static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter,
/// a new op that operates entirely on memrefs.
template <typename OpTy>
struct LinalgExtOpInterface
: public BufferizableOpInterface::ExternalModel<LinalgExtOpInterface<OpTy>,
OpTy> {
: public bufferization::DstBufferizableOpInterfaceExternalModel<
LinalgExtOpInterface<OpTy>, OpTy> {

bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// TODO: Implement payloadUsesValueFromOperand for individual ops. There
// are a limited number of LinalgExt ops, so we hardcode them here. We don't
// expect to add more LinalgExt ops.
if (!cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand))
return true;
// TODO: Revisit this for Scatter/ReverseOp. We can then get rid of
// `bufferizesToMemoryRead` completely.
return !isa<IREE::LinalgExt::ScatterOp, IREE::LinalgExt::ReverseOp>(op);
}

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// Operand is written to if it has an aliasing OpResult.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
return !bufferizableOp.getAliasingValues(opOperand, state)
.getAliases()
.empty();
}

bufferization::AliasingOpOperandList
getAliasingOpOperands(Operation *op, Value value,
const AnalysisState &state) const {
size_t resultNum = std::distance(op->getOpResults().begin(),
llvm::find(op->getOpResults(), value));
// The i-th OpResult may alias with the i-th "out" tensor.
return {AliasingOpOperand(
&cast<DestinationStyleOpInterface>(op).getDpsInitsMutable()[resultNum],
BufferRelation::Equivalent,
/*isDefinite=*/false)};
}

bufferization::AliasingValueList
getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto dspOp = cast<DestinationStyleOpInterface>(op);

// The i-th "out" tensor may alias with the i-th OpResult.
if (dspOp.isDpsInit(&opOperand)) {
return {AliasingValue(dspOp.getTiedOpResult(&opOperand) /*result*/,
BufferRelation::Equivalent,
/*isDefinite=*/false)};
}
return {};
}

bufferization::BufferRelation
bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return bufferization::BufferRelation::Equivalent;
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
return bufferizeLinalgExtOp(
Expand Down

0 comments on commit a53153a

Please sign in to comment.