diff --git a/lib/Transform/CMakeLists.txt b/lib/Transform/CMakeLists.txt
index 9caa085..28ac34e 100644
--- a/lib/Transform/CMakeLists.txt
+++ b/lib/Transform/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(Affine)
-add_subdirectory(Arith)
\ No newline at end of file
+add_subdirectory(Arith)
+add_subdirectory(Linalg)
\ No newline at end of file
diff --git a/lib/Transform/Linalg/CMakeLists.txt b/lib/Transform/Linalg/CMakeLists.txt
new file mode 100644
index 0000000..af6fb57
--- /dev/null
+++ b/lib/Transform/Linalg/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_dialect_library(MatmulToMmt4d
+    MatmulToMmt4d.h
+    MatmulToMmt4d.cpp
+
+    ADDITIONAL_HEADER_DIRS
+    ${PROJECT_SOURCE_DIR}/lib/Transform/Linalg/
+
+    LINK_LIBS PUBLIC
+)
\ No newline at end of file
diff --git a/lib/Transform/Linalg/MatmulToMmt4d.cpp b/lib/Transform/Linalg/MatmulToMmt4d.cpp
new file mode 100644
index 0000000..16297c1
--- /dev/null
+++ b/lib/Transform/Linalg/MatmulToMmt4d.cpp
@@ -0,0 +1,131 @@
+#include "lib/Transform/Linalg/MatmulToMmt4d.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
+#include <cstdint>
+
+namespace mlir {
+namespace dummy {
+
+static llvm::cl::opt<int64_t>
+    clMTile("dummy-m-tile", llvm::cl::desc("Inner tile size of M dimension"),
+            llvm::cl::init(32));
+
+static llvm::cl::opt<int64_t>
+    clNTile("dummy-n-tile", llvm::cl::desc("Inner tile size of N dimension"),
+            llvm::cl::init(32));
+
+static llvm::cl::opt<int64_t>
+    clKTile("dummy-k-tile", llvm::cl::desc("Inner tile size of K dimension"),
+            llvm::cl::init(32));
+
+// Rewrites a linalg.matmul on 2-D tensors into pack -> linalg.mmt4d -> unpack,
+// using the M0/N0/K0 inner tile sizes supplied on the command line.
+struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
+  Matmul(mlir::MLIRContext *context)
+      : OpRewritePattern<linalg::MatmulOp>(context, /*benefit=*/1) {}
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp op,
+                                PatternRewriter &rewriter) const override {
+    int64_t M0 = clMTile;
+    int64_t N0 = clNTile;
+    int64_t K0 = clKTile;
+
+    // DPS here means Destination Passing Style.
+    // Retrieves the input operands (LHS and RHS of the matmul).
+    auto inputs = op.getDpsInputOperands();
+    // Retrieves the DPS accumulator/init.
+    auto outputs = op.getDpsInits();
+
+    // Gets the type of each tensor by casting it to RankedTensorType.
+    auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
+    auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
+    auto resultType = cast<RankedTensorType>(outputs[0].getType());
+
+    Location loc = op.getLoc();
+    // Zero of the element type, used to pad sizes that are not tile-divisible.
+    Value paddingValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(lhsType.getElementType()));
+
+    // Returns the (possibly dynamic) dimensions of the given tensor value.
+    llvm::SmallVector<OpFoldResult> lhsSourceDims =
+        tensor::getMixedSizes(rewriter, loc, inputs[0]->get());
+    // Returns the ArrayAttr tile sizes as OpFoldResults.
+    llvm::SmallVector<OpFoldResult> lhsTileSizes =
+        getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, K0}));
+    SmallVector<int64_t> lhsInnerDimsPos = {0, 1};
+    // Computes the shape the pack result will have.
+    SmallVector<OpFoldResult> lhsResultDims =
+        linalg::PackOp::getResultShape(rewriter, loc, lhsSourceDims,
+                                       lhsTileSizes, lhsInnerDimsPos,
+                                       lhsInnerDimsPos);
+    tensor::EmptyOp emptyOp0 = rewriter.create<tensor::EmptyOp>(
+        loc, lhsResultDims, lhsType.getElementType());
+    linalg::PackOp lhsPack = rewriter.create<linalg::PackOp>(
+        loc, inputs[0]->get(), emptyOp0, lhsInnerDimsPos, lhsTileSizes,
+        paddingValue, lhsInnerDimsPos);
+
+    // RHS is packed with inner dims (1, 0): mmt4d expects its RHS transposed.
+    llvm::SmallVector<OpFoldResult> rhsSourceDims =
+        tensor::getMixedSizes(rewriter, loc, inputs[1]->get());
+    llvm::SmallVector<OpFoldResult> rhsTileSizes =
+        getAsOpFoldResult(rewriter.getI64ArrayAttr({N0, K0}));
+    SmallVector<int64_t> rhsInnerDimsPos = {1, 0};
+    SmallVector<OpFoldResult> rhsResultDims =
+        linalg::PackOp::getResultShape(rewriter, loc, rhsSourceDims,
+                                       rhsTileSizes, rhsInnerDimsPos,
+                                       rhsInnerDimsPos);
+    tensor::EmptyOp emptyOp1 = rewriter.create<tensor::EmptyOp>(
+        loc, rhsResultDims, rhsType.getElementType());
+    linalg::PackOp rhsPack = rewriter.create<linalg::PackOp>(
+        loc, inputs[1]->get(), emptyOp1, rhsInnerDimsPos, rhsTileSizes,
+        paddingValue, rhsInnerDimsPos);
+
+    // Pack the accumulator with (M0, N0) tiles.
+    llvm::SmallVector<OpFoldResult> resSourceDims =
+        tensor::getMixedSizes(rewriter, loc, outputs[0]);
+    llvm::SmallVector<OpFoldResult> resTileSizes =
+        getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, N0}));
+    SmallVector<int64_t> resInnerDimsPos = {0, 1};
+    SmallVector<OpFoldResult> resResultDims =
+        linalg::PackOp::getResultShape(rewriter, loc, resSourceDims,
+                                       resTileSizes, resInnerDimsPos,
+                                       resInnerDimsPos);
+    tensor::EmptyOp emptyOp2 = rewriter.create<tensor::EmptyOp>(
+        loc, resResultDims, resultType.getElementType());
+    linalg::PackOp resPack = rewriter.create<linalg::PackOp>(
+        loc, outputs[0], emptyOp2, resInnerDimsPos, resTileSizes,
+        paddingValue, resInnerDimsPos);
+
+    // ValueRange is just a view over the underlying data;
+    // it does not hold ownership of the operands.
+    linalg::Mmt4DOp mmt4d = rewriter.create<linalg::Mmt4DOp>(
+        loc, resPack.getResult().getType(),
+        ValueRange{lhsPack->getResult(0), rhsPack->getResult(0)},
+        ValueRange{resPack->getResult(0)});
+
+    // Unpack back to the original (un-tiled) result shape.
+    tensor::EmptyOp emptyOp3 = rewriter.create<tensor::EmptyOp>(
+        loc, resSourceDims, resultType.getElementType());
+    linalg::UnPackOp unpack = rewriter.create<linalg::UnPackOp>(
+        loc, mmt4d->getResult(0), emptyOp3, resInnerDimsPos, resTileSizes,
+        resInnerDimsPos);
+
+    // Replaces all uses of the MatmulOp with the UnPackOp and erases it.
+    rewriter.replaceOp(op, unpack);
+
+    return success();
+  }
+};
+
+void MatmulToMmt4dPass::runOnOperation() {
+  mlir::RewritePatternSet patterns(&getContext());
+  patterns.insert<Matmul>(&getContext());
+  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+}
+} // namespace dummy
+} // namespace mlir
diff --git a/lib/Transform/Linalg/MatmulToMmt4d.h b/lib/Transform/Linalg/MatmulToMmt4d.h
new file mode 100644
index 0000000..5ca5ae3
--- /dev/null
+++ b/lib/Transform/Linalg/MatmulToMmt4d.h
@@ -0,0 +1,25 @@
+#ifndef LIB_TRANSFORM_LINALG_MATMULTOMMT4D_H_
+#define LIB_TRANSFORM_LINALG_MATMULTOMMT4D_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace dummy {
+
+class MatmulToMmt4dPass
+    : public PassWrapper<MatmulToMmt4dPass, OperationPass<mlir::func::FuncOp>> {
+ private:
+  void runOnOperation() override;
+
+  StringRef getArgument() const final { return "matmul-to-mmt4d"; }
+
+  StringRef getDescription() const final {
+    return "Convert linalg.matmul to linalg.mmt4d";
+  }
+};
+
+} // namespace dummy
+} // namespace mlir
+
+#endif // LIB_TRANSFORM_LINALG_MATMULTOMMT4D_H_
\ No newline at end of file
diff --git a/tests/matmul_to_mmt4d.mlir b/tests/matmul_to_mmt4d.mlir
new file mode 100644
index 0000000..29f65b3
--- /dev/null
+++ b/tests/matmul_to_mmt4d.mlir
@@ -0,0 +1,14 @@
+// RUN: dummy-opt --matmul-to-mmt4d %s | FileCheck %s
+
+// CHECK-LABEL: matmul_f32
+func.func @matmul_f32(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %result = linalg.matmul ins(%lhs, %rhs: tensor<?x?xf32>, tensor<?x?xf32>) outs(%acc: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result: tensor<?x?xf32>
+}
+
+// CHECK: linalg.pack
+// CHECK: linalg.pack
+// CHECK: linalg.pack
+// CHECK-NOT: linalg.matmul
+// CHECK: linalg.mmt4d
+// CHECK: linalg.unpack
\ No newline at end of file
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index dfaa359..ef8b803 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -8,6 +8,7 @@ set (LIBS
     ${conversion_libs}
     MLIRPoly10x
     AffineFullUnroll
+    MatmulToMmt4d
     MulToAdd
     MLIROptLib
     MLIRPass
diff --git a/tools/dummy-opt.cpp b/tools/dummy-opt.cpp
index c357085..53ffbdb 100644
--- a/tools/dummy-opt.cpp
+++ b/tools/dummy-opt.cpp
@@ -6,6 +6,7 @@
 
 #include "lib/Transform/Affine/Passes.h"
 #include "lib/Transform/Arith/MulToAdd.h"
+#include "lib/Transform/Linalg/MatmulToMmt4d.h"
 
 #include "lib/Dialect/Poly10x/Poly10xDialect.h"
 
@@ -23,6 +24,7 @@ int main(int argc, char **argv) {
   mlir::dummy::registerAffinePasses();
   // register hand-authored passes
   mlir::PassRegistration<mlir::dummy::MulToAdd>();
+  mlir::PassRegistration<mlir::dummy::MatmulToMmt4dPass>();
 
   return mlir::asMainReturnCode(
       mlir::MlirOptMain(argc, argv, "Dummy Pass Driver", registry));