From 7ca85e1088825b8a296e00533c6aeb6936090f37 Mon Sep 17 00:00:00 2001
From: nouman-10x
Date: Tue, 11 Mar 2025 20:52:19 +0500
Subject: [PATCH 1/7] [Pass] Add matmul-to-mmt4d

Signed-off-by: nouman-10x
---
 lib/Transform/CMakeLists.txt           |   3 +-
 lib/Transform/Linalg/CMakeLists.txt    |   8 ++
 lib/Transform/Linalg/MatmulToMmt4d.cpp | 111 +++++++++++++++++++++++++
 lib/Transform/Linalg/MatmulToMmt4d.h   |  25 ++++++
 tests/matmul_to_mmt4d.mlir             |   7 ++
 tools/CMakeLists.txt                   |   1 +
 tools/dummy-opt.cpp                    |   2 +
 7 files changed, 156 insertions(+), 1 deletion(-)
 create mode 100644 lib/Transform/Linalg/CMakeLists.txt
 create mode 100644 lib/Transform/Linalg/MatmulToMmt4d.cpp
 create mode 100644 lib/Transform/Linalg/MatmulToMmt4d.h
 create mode 100644 tests/matmul_to_mmt4d.mlir

diff --git a/lib/Transform/CMakeLists.txt b/lib/Transform/CMakeLists.txt
index 9caa085..28ac34e 100644
--- a/lib/Transform/CMakeLists.txt
+++ b/lib/Transform/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(Affine)
-add_subdirectory(Arith)
\ No newline at end of file
+add_subdirectory(Arith)
+add_subdirectory(Linalg)
\ No newline at end of file
diff --git a/lib/Transform/Linalg/CMakeLists.txt b/lib/Transform/Linalg/CMakeLists.txt
new file mode 100644
index 0000000..af6fb57
--- /dev/null
+++ b/lib/Transform/Linalg/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_dialect_library(MatmulToMmt4d
+    MatmulToMmt4d.h
+    MatmulToMmt4d.cpp
+
+    ADDITIONAL_HEADER_DIRS
+    ${PROJECT_SOURCE_DIR}/lib/Transform/Linalg/
+    LINK_LIBS PUBLIC
+)
\ No newline at end of file
diff --git a/lib/Transform/Linalg/MatmulToMmt4d.cpp b/lib/Transform/Linalg/MatmulToMmt4d.cpp
new file mode 100644
index 0000000..ca96081
--- /dev/null
+++ b/lib/Transform/Linalg/MatmulToMmt4d.cpp
@@ -0,0 +1,111 @@
+#include "lib/Transform/Linalg/MatmulToMmt4d.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <utility>
+
+namespace mlir {
+  namespace dummy {
+
+    struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
+      Matmul(mlir::MLIRContext *context) : OpRewritePattern<linalg::MatmulOp>(context, /*benefit=*/1) {}
+
+      LogicalResult matchAndRewrite(linalg::MatmulOp op, PatternRewriter &rewriter) const override {
+        // llvm::outs() << "Matched a MatmulOp\n";
+        int64_t M0 = 512;
+        int64_t N0 = 32;
+        int64_t K0 = 32;
+
+        auto inputs = op.getDpsInputOperands();
+        auto outputs = op.getDpsInits();
+
+        auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
+        auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
+        auto resultType = cast<RankedTensorType>(outputs[0].getType());
+
+        // auto lhsShape = lhsType.getShape();
+        // auto rhsShape = rhsType.getShape();
+        // auto resultShape = resultType.getShape();
+
+        // llvm::outs() << "LHS Shape: " << lhsShape[0] << "x" << lhsShape[1] << '\n';
+        // llvm::outs() << "RHS Shape: " << rhsShape[0] << "x" << rhsShape[1] << '\n';
+        // llvm::outs() << "Result Shape: " << resultShape[0] << "x" << resultShape[1] << '\n';
+
+        // for (OpFoldResult result : final) {
+        //   // Process each OpFoldResult
+        //   if (auto val = result.dyn_cast<Value>()) {
+        //     llvm::outs() << "Value: " << val << "\n";
+        //   } else if (auto attr = result.dyn_cast<Attribute>()) {
+        //     llvm::outs() << "Attribute: " << attr << "\n";
+        //   }
+        // }
+
+        Location loc = op.getLoc();
+        Value paddingValue = rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(lhsType.getElementType()));
+
+        llvm::SmallVector<OpFoldResult> lhsSourceDims = tensor::getMixedSizes(rewriter, loc, inputs[0]->get());
+        llvm::SmallVector<OpFoldResult> lhsTileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, K0}));
+        SmallVector<int64_t> lhsInnerDimsPos = {0, 1};
+        SmallVector<OpFoldResult> lhsResultDims = linalg::PackOp::getResultShape(
+            rewriter, loc, lhsSourceDims, lhsTileSizes, lhsInnerDimsPos,
+            lhsInnerDimsPos);
+        tensor::EmptyOp emptyOp0 = rewriter.create<tensor::EmptyOp>(loc, lhsResultDims, lhsType.getElementType());
+        linalg::PackOp lhsPack = rewriter.create<linalg::PackOp>(loc, inputs[0]->get(), emptyOp0, lhsInnerDimsPos, lhsTileSizes, paddingValue, lhsInnerDimsPos);
+
+        llvm::SmallVector<OpFoldResult> rhsSourceDims = tensor::getMixedSizes(rewriter, loc, inputs[1]->get());
+        llvm::SmallVector<OpFoldResult> rhsTileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr({N0, K0}));
+        SmallVector<int64_t> rhsInnerDimsPos = {1, 0};
+        SmallVector<OpFoldResult> rhsResultDims = linalg::PackOp::getResultShape(
+            rewriter, loc, rhsSourceDims, rhsTileSizes, rhsInnerDimsPos,
+            rhsInnerDimsPos);
+        tensor::EmptyOp emptyOp1 = rewriter.create<tensor::EmptyOp>(loc, rhsResultDims, rhsType.getElementType());
+        linalg::PackOp rhsPack = rewriter.create<linalg::PackOp>(loc, inputs[1]->get(), emptyOp1, rhsInnerDimsPos, rhsTileSizes, paddingValue, rhsInnerDimsPos);
+
+        llvm::SmallVector<OpFoldResult> resSourceDims = tensor::getMixedSizes(rewriter, loc, outputs[0]);
+        llvm::SmallVector<OpFoldResult> resTileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, N0}));
+        SmallVector<int64_t> resInnerDimsPos = {0, 1};
+        SmallVector<OpFoldResult> resResultDims = linalg::PackOp::getResultShape(
+            rewriter, loc, resSourceDims, resTileSizes, resInnerDimsPos,
+            resInnerDimsPos);
+        tensor::EmptyOp emptyOp2 = rewriter.create<tensor::EmptyOp>(loc, resResultDims, resultType.getElementType());
+        linalg::PackOp resPack = rewriter.create<linalg::PackOp>(loc, outputs[0], emptyOp2, resInnerDimsPos, resTileSizes, paddingValue, resInnerDimsPos);
+
+        linalg::Mmt4DOp mmt4d = rewriter.create<linalg::Mmt4DOp>(loc, resPack.getResult().getType(), ValueRange{lhsPack->getResult(0), rhsPack->getResult(0)}, ValueRange{resPack->getResult(0)});
+
+        llvm::SmallVector<OpFoldResult> mmt4dDims = tensor::getMixedSizes(rewriter, loc, mmt4d.getDpsInits()[0]);
+        tensor::EmptyOp emptyOp3 = rewriter.create<tensor::EmptyOp>(loc, resSourceDims, resultType.getElementType());
+        linalg::UnPackOp unpack = rewriter.create<linalg::UnPackOp>(loc, mmt4d->getResult(0), emptyOp3, resInnerDimsPos, resTileSizes, resInnerDimsPos);
+
+        rewriter.replaceAllOpUsesWith(op, unpack);
+        rewriter.eraseOp(op);
+
+        // for (OpFoldResult result : mmt4dDims) {
+        //   // Process each OpFoldResult
+        //   if (auto val = result.dyn_cast<Value>()) {
+        //     llvm::outs() << "Value: " << val << "\n";
+        //   } else if (auto attr = result.dyn_cast<Attribute>()) {
+        //     llvm::outs() << "Attribute: " << attr << "\n";
+        //   }
+        // }
+        // llvm::outs() << "Hell0\n";
+        return success();
+      }
+    };
+
+    void MatmulToMmt4dPass::runOnOperation() {
+      mlir::RewritePatternSet patterns(&getContext());
+      patterns.insert<Matmul>(&getContext());
+      (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+    }
+  }
+}
\ No newline at end of file
diff --git a/lib/Transform/Linalg/MatmulToMmt4d.h b/lib/Transform/Linalg/MatmulToMmt4d.h
new file mode 100644
index 0000000..5ca5ae3
--- /dev/null
+++ b/lib/Transform/Linalg/MatmulToMmt4d.h
@@ -0,0 +1,25 @@
+#ifndef LIB_TRANSFORM_LINALG_MATMULTOMMT4D_H_
+#define LIB_TRANSFORM_LINALG_MATMULTOMMT4D_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace dummy {
+
+class MatmulToMmt4dPass
+    : public PassWrapper<MatmulToMmt4dPass,
+                         OperationPass<mlir::func::FuncOp>> {
+ private:
+  void runOnOperation() override;
+
+  StringRef getArgument() const final { return "matmul-to-mmt4d"; }
+
+  StringRef getDescription() const final {
+    return "Convert linalg.matmul to linalg.mmt4d";
+  }
+};
+
+}  // namespace dummy
+}  // namespace mlir
+
+#endif  // LIB_TRANSFORM_LINALG_MATMULTOMMT4D_H_
\ No newline at end of file
diff --git a/tests/matmul_to_mmt4d.mlir b/tests/matmul_to_mmt4d.mlir
new file mode 100644
index 0000000..291bedb
--- /dev/null
+++ b/tests/matmul_to_mmt4d.mlir
@@ -0,0 +1,7 @@
+// RUN: dummy-opt --matmul-to-mmt4d %s | FileCheck %s
+
+// CHECK-LABEL: matmul_f32
+func.func @matmul_f32(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %result = linalg.matmul ins(%lhs, %rhs: tensor<?x?xf32>, tensor<?x?xf32>) outs(%acc: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result: tensor<?x?xf32>
+}
\ No newline at end of file
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index dfaa359..ef8b803 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -8,6 +8,7 @@ set (LIBS
     ${conversion_libs}
     MLIRPoly10x
     AffineFullUnroll
+    MatmulToMmt4d
     MulToAdd
     MLIROptLib
     MLIRPass
diff --git a/tools/dummy-opt.cpp b/tools/dummy-opt.cpp
index c357085..53ffbdb 100644
--- a/tools/dummy-opt.cpp
+++ b/tools/dummy-opt.cpp
@@ -6,6 +6,7 @@
 
 #include "lib/Transform/Affine/Passes.h"
 #include "lib/Transform/Arith/MulToAdd.h"
+#include "lib/Transform/Linalg/MatmulToMmt4d.h"
 
 #include "lib/Dialect/Poly10x/Poly10xDialect.h"
 
@@ -23,6 +24,7 @@ int main(int argc, char **argv) {
   mlir::dummy::registerAffinePasses();
   // register hand-authored passes
   mlir::PassRegistration<mlir::dummy::MulToAdd>();
+  mlir::PassRegistration<mlir::dummy::MatmulToMmt4dPass>();
 
   return mlir::asMainReturnCode(
       mlir::MlirOptMain(argc, argv, "Dummy Pass Driver", registry));

From c852a42c5b7fb243ec79e132522f254de5502a0c Mon Sep 17 00:00:00 2001
From: nouman-10x
Date: Tue, 11 Mar 2025 20:57:00 +0500
Subject: [PATCH 2/7] Formatted files

Signed-off-by: nouman-10x
---
 lib/Transform/Linalg/MatmulToMmt4d.cpp | 218 ++++++++++++++-----------
 1 file changed, 123 insertions(+), 95 deletions(-)

diff --git a/lib/Transform/Linalg/MatmulToMmt4d.cpp b/lib/Transform/Linalg/MatmulToMmt4d.cpp
index ca96081..22f8b89 100644
--- a/lib/Transform/Linalg/MatmulToMmt4d.cpp
+++ b/lib/Transform/Linalg/MatmulToMmt4d.cpp
@@ -1,6 +1,6 @@
 #include "lib/Transform/Linalg/MatmulToMmt4d.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -14,98 +14,126 @@
 #include <utility>
 
 namespace mlir {
-  namespace dummy {
-
-    struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
-      Matmul(mlir::MLIRContext *context) : OpRewritePattern<linalg::MatmulOp>(context, /*benefit=*/1) {}
-
-      LogicalResult matchAndRewrite(linalg::MatmulOp op, PatternRewriter &rewriter) const override {
-        // llvm::outs() << "Matched a MatmulOp\n";
-        int64_t M0 = 512;
-        int64_t N0 = 32;
-        int64_t K0 = 32;
-
-        auto inputs = op.getDpsInputOperands();
-        auto outputs = op.getDpsInits();
-
-        auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
-        auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
-        auto resultType = cast<RankedTensorType>(outputs[0].getType());
-
-        // auto lhsShape = lhsType.getShape();
-        // auto rhsShape = rhsType.getShape();
-        // auto resultShape = resultType.getShape();
-
-        // llvm::outs() << "LHS Shape: " << lhsShape[0] << "x" << lhsShape[1] << '\n';
-        // llvm::outs() << "RHS Shape: " << rhsShape[0] << "x" << rhsShape[1] << '\n';
-        // llvm::outs() << "Result Shape: " << resultShape[0] << "x" << resultShape[1] << '\n';
-
-        // for (OpFoldResult result : final) {
-        //   // Process each OpFoldResult
-        //   if (auto val = result.dyn_cast<Value>()) {
-        //     llvm::outs() << "Value: " << val << "\n";
-        //   } else if (auto attr = result.dyn_cast<Attribute>()) {
-        //     llvm::outs() << "Attribute: " << attr << "\n";
-        //   }
-        // }
-
-        Location loc = op.getLoc();
-        Value paddingValue = rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(lhsType.getElementType()));
-
-        llvm::SmallVector<OpFoldResult> lhsSourceDims = tensor::getMixedSizes(rewriter, loc, inputs[0]->get());
-        llvm::SmallVector<OpFoldResult> lhsTileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, K0}));
-        SmallVector<int64_t> lhsInnerDimsPos = {0, 1};
-        SmallVector<OpFoldResult> lhsResultDims = linalg::PackOp::getResultShape(
-            rewriter, loc, lhsSourceDims, lhsTileSizes, lhsInnerDimsPos,
-            lhsInnerDimsPos);
-        tensor::EmptyOp emptyOp0 = rewriter.create<tensor::EmptyOp>(loc, lhsResultDims, lhsType.getElementType());
-        linalg::PackOp lhsPack = rewriter.create<linalg::PackOp>(loc, inputs[0]->get(), emptyOp0, lhsInnerDimsPos, lhsTileSizes, paddingValue, lhsInnerDimsPos);
-
-        llvm::SmallVector<OpFoldResult> rhsSourceDims = tensor::getMixedSizes(rewriter, loc, inputs[1]->get());
-        llvm::SmallVector<OpFoldResult> rhsTileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr({N0, K0}));
-        SmallVector<int64_t> rhsInnerDimsPos = {1, 0};
-        SmallVector<OpFoldResult> rhsResultDims = linalg::PackOp::getResultShape(
-            rewriter, loc, rhsSourceDims, rhsTileSizes, rhsInnerDimsPos,
-            rhsInnerDimsPos);
-        tensor::EmptyOp emptyOp1 = rewriter.create<tensor::EmptyOp>(loc, rhsResultDims, rhsType.getElementType());
-        linalg::PackOp rhsPack = rewriter.create<linalg::PackOp>(loc, inputs[1]->get(), emptyOp1, rhsInnerDimsPos, rhsTileSizes, paddingValue, rhsInnerDimsPos);
-
-        llvm::SmallVector<OpFoldResult> resSourceDims = tensor::getMixedSizes(rewriter, loc, outputs[0]);
-        llvm::SmallVector<OpFoldResult> resTileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, N0}));
-        SmallVector<int64_t> resInnerDimsPos = {0, 1};
-        SmallVector<OpFoldResult> resResultDims = linalg::PackOp::getResultShape(
-            rewriter, loc, resSourceDims, resTileSizes, resInnerDimsPos,
-            resInnerDimsPos);
-        tensor::EmptyOp emptyOp2 = rewriter.create<tensor::EmptyOp>(loc, resResultDims, resultType.getElementType());
-        linalg::PackOp resPack = rewriter.create<linalg::PackOp>(loc, outputs[0], emptyOp2, resInnerDimsPos, resTileSizes, paddingValue, resInnerDimsPos);
-
-        linalg::Mmt4DOp mmt4d = rewriter.create<linalg::Mmt4DOp>(loc, resPack.getResult().getType(), ValueRange{lhsPack->getResult(0), rhsPack->getResult(0)}, ValueRange{resPack->getResult(0)});
-
-        llvm::SmallVector<OpFoldResult> mmt4dDims = tensor::getMixedSizes(rewriter, loc, mmt4d.getDpsInits()[0]);
-        tensor::EmptyOp emptyOp3 = rewriter.create<tensor::EmptyOp>(loc, resSourceDims, resultType.getElementType());
-        linalg::UnPackOp unpack = rewriter.create<linalg::UnPackOp>(loc, mmt4d->getResult(0), emptyOp3, resInnerDimsPos, resTileSizes, resInnerDimsPos);
-
-        rewriter.replaceAllOpUsesWith(op, unpack);
-        rewriter.eraseOp(op);
-
-        // for (OpFoldResult result : mmt4dDims) {
-        //   // Process each OpFoldResult
-        //   if (auto val = result.dyn_cast<Value>()) {
-        //     llvm::outs() << "Value: " << val << "\n";
-        //   } else if (auto attr = result.dyn_cast<Attribute>()) {
-        //     llvm::outs() << "Attribute: " << attr << "\n";
-        //   }
-        // }
-        // llvm::outs() << "Hell0\n";
-        return success();
-      }
-    };
-
-    void MatmulToMmt4dPass::runOnOperation() {
-      mlir::RewritePatternSet patterns(&getContext());
-      patterns.insert<Matmul>(&getContext());
-      (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-    }
-  }
-}
\ No newline at end of file
+namespace dummy {
+
+struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
+  Matmul(mlir::MLIRContext *context)
+      : OpRewritePattern<linalg::MatmulOp>(context, /*benefit=*/1) {}
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp op,
+                                PatternRewriter &rewriter) const override {
+    // llvm::outs() << "Matched a MatmulOp\n";
+    int64_t M0 = 512;
+    int64_t N0 = 32;
+    int64_t K0 = 32;
+
+    auto inputs = op.getDpsInputOperands();
+    auto outputs = op.getDpsInits();
+
+    auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
+    auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
+    auto resultType = cast<RankedTensorType>(outputs[0].getType());
+
+    // auto lhsShape = lhsType.getShape();
+    // auto rhsShape = rhsType.getShape();
+    // auto resultShape = resultType.getShape();
+
+    // llvm::outs() << "LHS Shape: " << lhsShape[0] << "x" << lhsShape[1] <<
+    // '\n'; llvm::outs() << "RHS Shape: " << rhsShape[0] << "x" <<
+    // rhsShape[1] << '\n'; llvm::outs() << "Result Shape: " <<
+    // resultShape[0] << "x" << resultShape[1] << '\n';
+
+    // for (OpFoldResult result : final) {
+    //   // Process each OpFoldResult
+    //   if (auto val = result.dyn_cast<Value>()) {
+    //     llvm::outs() << "Value: " << val << "\n";
+    //   } else if (auto attr = result.dyn_cast<Attribute>()) {
+    //     llvm::outs() << "Attribute: " << attr << "\n";
+    //   }
+    // }
+
+    Location loc = op.getLoc();
+    Value paddingValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(lhsType.getElementType()));
+
+    llvm::SmallVector<OpFoldResult> lhsSourceDims =
+        tensor::getMixedSizes(rewriter, loc, inputs[0]->get());
+    llvm::SmallVector<OpFoldResult> lhsTileSizes =
+        getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, K0}));
+    SmallVector<int64_t> lhsInnerDimsPos = {0, 1};
+    SmallVector<OpFoldResult> lhsResultDims =
+        linalg::PackOp::getResultShape(rewriter, loc, lhsSourceDims,
+                                       lhsTileSizes, lhsInnerDimsPos,
+                                       lhsInnerDimsPos);
+    tensor::EmptyOp emptyOp0 = rewriter.create<tensor::EmptyOp>(
+        loc, lhsResultDims, lhsType.getElementType());
+    linalg::PackOp lhsPack = rewriter.create<linalg::PackOp>(
+        loc, inputs[0]->get(), emptyOp0, lhsInnerDimsPos, lhsTileSizes,
+        paddingValue, lhsInnerDimsPos);
+
+    llvm::SmallVector<OpFoldResult> rhsSourceDims =
+        tensor::getMixedSizes(rewriter, loc, inputs[1]->get());
+    llvm::SmallVector<OpFoldResult> rhsTileSizes =
+        getAsOpFoldResult(rewriter.getI64ArrayAttr({N0, K0}));
+    SmallVector<int64_t> rhsInnerDimsPos = {1, 0};
+    SmallVector<OpFoldResult> rhsResultDims =
+        linalg::PackOp::getResultShape(rewriter, loc, rhsSourceDims,
+                                       rhsTileSizes, rhsInnerDimsPos,
+                                       rhsInnerDimsPos);
+    tensor::EmptyOp emptyOp1 = rewriter.create<tensor::EmptyOp>(
+        loc, rhsResultDims, rhsType.getElementType());
+    linalg::PackOp rhsPack = rewriter.create<linalg::PackOp>(
+        loc, inputs[1]->get(), emptyOp1, rhsInnerDimsPos, rhsTileSizes,
+        paddingValue, rhsInnerDimsPos);
+
+    llvm::SmallVector<OpFoldResult> resSourceDims =
+        tensor::getMixedSizes(rewriter, loc, outputs[0]);
+    llvm::SmallVector<OpFoldResult> resTileSizes =
+        getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, N0}));
+    SmallVector<int64_t> resInnerDimsPos = {0, 1};
+    SmallVector<OpFoldResult> resResultDims =
+        linalg::PackOp::getResultShape(rewriter, loc, resSourceDims,
+                                       resTileSizes, resInnerDimsPos,
+                                       resInnerDimsPos);
+    tensor::EmptyOp emptyOp2 = rewriter.create<tensor::EmptyOp>(
+        loc, resResultDims, resultType.getElementType());
+    linalg::PackOp resPack = rewriter.create<linalg::PackOp>(
+        loc, outputs[0], emptyOp2, resInnerDimsPos, resTileSizes,
+        paddingValue, resInnerDimsPos);
+
+    linalg::Mmt4DOp mmt4d = rewriter.create<linalg::Mmt4DOp>(
+        loc, resPack.getResult().getType(),
+        ValueRange{lhsPack->getResult(0), rhsPack->getResult(0)},
+        ValueRange{resPack->getResult(0)});
+
+    llvm::SmallVector<OpFoldResult> mmt4dDims =
+        tensor::getMixedSizes(rewriter, loc, mmt4d.getDpsInits()[0]);
+    tensor::EmptyOp emptyOp3 = rewriter.create<tensor::EmptyOp>(
+        loc, resSourceDims, resultType.getElementType());
+    linalg::UnPackOp unpack = rewriter.create<linalg::UnPackOp>(
+        loc, mmt4d->getResult(0), emptyOp3, resInnerDimsPos, resTileSizes,
+        resInnerDimsPos);
+
+    rewriter.replaceAllOpUsesWith(op, unpack);
+    rewriter.eraseOp(op);
+
+    // for (OpFoldResult result : mmt4dDims) {
+    //   // Process each OpFoldResult
+    //   if (auto val = result.dyn_cast<Value>()) {
+    //     llvm::outs() << "Value: " << val << "\n";
+    //   } else if (auto attr = result.dyn_cast<Attribute>()) {
+    //     llvm::outs() << "Attribute: " << attr << "\n";
+    //   }
+    // }
+    // llvm::outs() << "Hell0\n";
+    return success();
+  }
+};
+
+void MatmulToMmt4dPass::runOnOperation() {
+  mlir::RewritePatternSet patterns(&getContext());
+  patterns.insert<Matmul>(&getContext());
+  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+}
+}  // namespace dummy
+}  // namespace mlir
\ No newline at end of file

From 22d58afbffc8e0705aa5e9d6c8d3326a67406464 Mon Sep 17 00:00:00 2001
From: nouman-10x
Date: Wed, 12 Mar 2025 20:20:25 +0500
Subject: [PATCH 3/7] Commented code and added TODOs

Signed-off-by: nouman-10x
---
 lib/Transform/Linalg/MatmulToMmt4d.cpp | 55 +++++++++-----------------
 tests/matmul_to_mmt4d.mlir             |  2 +-
 2 files changed, 20 insertions(+), 37 deletions(-)

diff --git a/lib/Transform/Linalg/MatmulToMmt4d.cpp b/lib/Transform/Linalg/MatmulToMmt4d.cpp
index 22f8b89..ab1ca62 100644
--- a/lib/Transform/Linalg/MatmulToMmt4d.cpp
+++ b/lib/Transform/Linalg/MatmulToMmt4d.cpp
@@ -22,45 +22,34 @@ struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
 
   LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                 PatternRewriter &rewriter) const override {
-    // llvm::outs() << "Matched a MatmulOp\n";
-    int64_t M0 = 512;
+    // TODO: Change these to command-line arguments
+    int64_t M0 = 32;
     int64_t N0 = 32;
     int64_t K0 = 32;
-
+
+    // DPS here means Destination Passing Style
+    // retrieves the input operands
     auto inputs = op.getDpsInputOperands();
+    // retrieves the DPS accumulator/init
    auto outputs = op.getDpsInits();
-
+
+    // gets the type of the given tensor by casting it to RankedTensorType
     auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
     auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
     auto resultType = cast<RankedTensorType>(outputs[0].getType());
 
-    // auto lhsShape = lhsType.getShape();
-    // auto rhsShape = rhsType.getShape();
-    // auto resultShape = resultType.getShape();
-
-    // llvm::outs() << "LHS Shape: " << lhsShape[0] << "x" << lhsShape[1] <<
-    // '\n'; llvm::outs() << "RHS Shape: " << rhsShape[0] << "x" <<
-    // rhsShape[1] << '\n'; llvm::outs() << "Result Shape: " <<
-    // resultShape[0] << "x" << resultShape[1] << '\n';
-
-    // for (OpFoldResult result : final) {
-    //   // Process each OpFoldResult
-    //   if (auto val = result.dyn_cast<Value>()) {
-    //     llvm::outs() << "Value: " << val << "\n";
-    //   } else if (auto attr = result.dyn_cast<Attribute>()) {
-    //     llvm::outs() << "Attribute: " << attr << "\n";
-    //   }
-    // }
-
     Location loc = op.getLoc();
     Value paddingValue = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(lhsType.getElementType()));
-
+
+    // returns the dimensions of the given tensor value
     llvm::SmallVector<OpFoldResult> lhsSourceDims =
         tensor::getMixedSizes(rewriter, loc, inputs[0]->get());
+    // returns the ArrayAttr as an OpFoldResult
     llvm::SmallVector<OpFoldResult> lhsTileSizes =
         getAsOpFoldResult(rewriter.getI64ArrayAttr({M0, K0}));
     SmallVector<int64_t> lhsInnerDimsPos = {0, 1};
+    // returns the shape that the pack op would produce
     SmallVector<OpFoldResult> lhsResultDims =
         linalg::PackOp::getResultShape(rewriter, loc, lhsSourceDims,
                                        lhsTileSizes, lhsInnerDimsPos,
                                        lhsInnerDimsPos);
@@ -100,40 +89,34 @@ struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
     linalg::PackOp resPack = rewriter.create<linalg::PackOp>(
         loc, outputs[0], emptyOp2, resInnerDimsPos, resTileSizes,
         paddingValue, resInnerDimsPos);
-
+
+    // TODO: What is ValueRange?
     linalg::Mmt4DOp mmt4d = rewriter.create<linalg::Mmt4DOp>(
         loc, resPack.getResult().getType(),
         ValueRange{lhsPack->getResult(0), rhsPack->getResult(0)},
         ValueRange{resPack->getResult(0)});
 
     llvm::SmallVector<OpFoldResult> mmt4dDims =
         tensor::getMixedSizes(rewriter, loc, mmt4d.getDpsInits()[0]);
     tensor::EmptyOp emptyOp3 = rewriter.create<tensor::EmptyOp>(
         loc, resSourceDims, resultType.getElementType());
     linalg::UnPackOp unpack = rewriter.create<linalg::UnPackOp>(
         loc, mmt4d->getResult(0), emptyOp3, resInnerDimsPos, resTileSizes,
         resInnerDimsPos);
-
+
+    // This replaces the uses of MatmulOp with UnpackOp
     rewriter.replaceAllOpUsesWith(op, unpack);
+    // erases the MatmulOp
     rewriter.eraseOp(op);
-
-    // for (OpFoldResult result : mmt4dDims) {
-    //   // Process each OpFoldResult
-    //   if (auto val = result.dyn_cast<Value>()) {
-    //     llvm::outs() << "Value: " << val << "\n";
-    //   } else if (auto attr = result.dyn_cast<Attribute>()) {
-    //     llvm::outs() << "Attribute: " << attr << "\n";
-    //   }
-    // }
-    // llvm::outs() << "Hell0\n";
+
     return success();
   }
 };
 
 void MatmulToMmt4dPass::runOnOperation() {
   mlir::RewritePatternSet patterns(&getContext());
   patterns.insert<Matmul>(&getContext());
   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
 }
 }  // namespace dummy
-}  // namespace mlir
\ No newline at end of file
+}  // namespace mlir
diff --git a/tests/matmul_to_mmt4d.mlir b/tests/matmul_to_mmt4d.mlir
index 291bedb..444108e 100644
--- a/tests/matmul_to_mmt4d.mlir
+++ b/tests/matmul_to_mmt4d.mlir
@@ -4,4 +4,4 @@
 func.func @matmul_f32(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %result = linalg.matmul ins(%lhs, %rhs: tensor<?x?xf32>, tensor<?x?xf32>) outs(%acc: tensor<?x?xf32>) -> tensor<?x?xf32>
   return %result: tensor<?x?xf32>
-}
\ No newline at end of file
+}

From 261eca8875ead37ed977c85f21962318d5e9fb17 Mon Sep 17 00:00:00 2001
From: nouman-10x
Date: Wed, 12 Mar 2025 20:25:46 +0500
Subject: [PATCH 4/7] Updated the matmul_to_mmt4d.mlir test with new CHECK
 statements

Signed-off-by: nouman-10x
---
 lib/Transform/Linalg/MatmulToMmt4d.cpp | 12 ++++++------
 tests/matmul_to_mmt4d.mlir             |  7 +++++++
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/lib/Transform/Linalg/MatmulToMmt4d.cpp b/lib/Transform/Linalg/MatmulToMmt4d.cpp
index ab1ca62..7e6fb5c 100644
--- a/lib/Transform/Linalg/MatmulToMmt4d.cpp
+++ b/lib/Transform/Linalg/MatmulToMmt4d.cpp
@@ -26,13 +26,13 @@ struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
     int64_t M0 = 32;
     int64_t N0 = 32;
     int64_t K0 = 32;
-
+
     // DPS here means Destination Passing Style
     // retrieves the input operands
     auto inputs = op.getDpsInputOperands();
     // retrieves the DPS accumulator/init
     auto outputs = op.getDpsInits();
-
+
     // gets the type of the given tensor by casting it to RankedTensorType
     auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
     auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
     auto resultType = cast<RankedTensorType>(outputs[0].getType());
@@ -41,7 +41,7 @@ struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
     Location loc = op.getLoc();
     Value paddingValue = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(lhsType.getElementType()));
-
+
     // returns the dimensions of the given tensor value
     llvm::SmallVector<OpFoldResult> lhsSourceDims =
         tensor::getMixedSizes(rewriter, loc, inputs[0]->get());
@@ -89,25 +89,25 @@ struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
     linalg::PackOp resPack = rewriter.create<linalg::PackOp>(
         loc, outputs[0], emptyOp2, resInnerDimsPos, resTileSizes,
         paddingValue, resInnerDimsPos);
-
+
     // TODO: What is ValueRange?
     linalg::Mmt4DOp mmt4d = rewriter.create<linalg::Mmt4DOp>(
         loc, resPack.getResult().getType(),
         ValueRange{lhsPack->getResult(0), rhsPack->getResult(0)},
         ValueRange{resPack->getResult(0)});
 
     llvm::SmallVector<OpFoldResult> mmt4dDims =
         tensor::getMixedSizes(rewriter, loc, mmt4d.getDpsInits()[0]);
     tensor::EmptyOp emptyOp3 = rewriter.create<tensor::EmptyOp>(
         loc, resSourceDims, resultType.getElementType());
     linalg::UnPackOp unpack = rewriter.create<linalg::UnPackOp>(
         loc, mmt4d->getResult(0), emptyOp3, resInnerDimsPos, resTileSizes,
         resInnerDimsPos);
-
+
     // This replaces the uses of MatmulOp with UnpackOp
     rewriter.replaceAllOpUsesWith(op, unpack);
     // erases the MatmulOp
     rewriter.eraseOp(op);
-
+
     return success();
   }
 };
diff --git a/tests/matmul_to_mmt4d.mlir b/tests/matmul_to_mmt4d.mlir
index 444108e..1d80289 100644
--- a/tests/matmul_to_mmt4d.mlir
+++ b/tests/matmul_to_mmt4d.mlir
@@ -5,3 +5,10 @@ func.func @matmul_f32(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %result = linalg.matmul ins(%lhs, %rhs: tensor<?x?xf32>, tensor<?x?xf32>) outs(%acc: tensor<?x?xf32>) -> tensor<?x?xf32>
   return %result: tensor<?x?xf32>
 }
+
+// CHECK: linalg.pack
+// CHECK: linalg.pack
+// CHECK: linalg.pack
+// CHECK-NOT: linalg.matmul
+// CHECK: linalg.mmt4d
+// CHECK: linalg.unpack

From ed6a5ec5796c4fe0bb87a5e5db24ea44dd869df1 Mon Sep 17 00:00:00 2001
From: nouman-10x
Date: Wed, 12 Mar 2025 20:26:23 +0500
Subject: [PATCH 5/7] Removed newline

Signed-off-by: nouman-10x
---
 tests/matmul_to_mmt4d.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/matmul_to_mmt4d.mlir b/tests/matmul_to_mmt4d.mlir
index 1d80289..29f65b3 100644
--- a/tests/matmul_to_mmt4d.mlir
+++ b/tests/matmul_to_mmt4d.mlir
@@ -11,4 +11,4 @@ func.func @matmul_f32(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> {
 // CHECK: linalg.pack
 // CHECK-NOT: linalg.matmul
 // CHECK: linalg.mmt4d
-// CHECK: linalg.unpack
+// CHECK: linalg.unpack
\ No newline at end of file

From f19b8911771e5f76006390a166b32db6efaf083d Mon Sep 17 00:00:00 2001
From: nouman-10x
Date: Wed, 12 Mar 2025 20:39:20 +0500
Subject: [PATCH 6/7] Added tile sizes as command-line arguments

Signed-off-by: nouman-10x
---
 lib/Transform/Linalg/MatmulToMmt4d.cpp | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/lib/Transform/Linalg/MatmulToMmt4d.cpp b/lib/Transform/Linalg/MatmulToMmt4d.cpp
index 7e6fb5c..1374947 100644
--- a/lib/Transform/Linalg/MatmulToMmt4d.cpp
+++ b/lib/Transform/Linalg/MatmulToMmt4d.cpp
@@ -16,16 +16,27 @@
 namespace mlir {
 namespace dummy {
 
+static llvm::cl::opt<int64_t>
+    clMTile("dummy-m-tile", llvm::cl::desc("Inner tile size of M dimension"),
+            llvm::cl::init(32));
+
+static llvm::cl::opt<int64_t>
+    clNTile("dummy-n-tile", llvm::cl::desc("Inner tile size of N dimension"),
+            llvm::cl::init(32));
+
+static llvm::cl::opt<int64_t>
+    clKTile("dummy-k-tile", llvm::cl::desc("Inner tile size of K dimension"),
+            llvm::cl::init(32));
+
 struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
   Matmul(mlir::MLIRContext *context)
       : OpRewritePattern<linalg::MatmulOp>(context, /*benefit=*/1) {}
 
   LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: Change these to command-line arguments
-    int64_t M0 = 32;
-    int64_t N0 = 32;
-    int64_t K0 = 32;
+    int64_t M0 = clMTile;
+    int64_t N0 = clNTile;
+    int64_t K0 = clKTile;
 
     // DPS here means Destination Passing Style
     // retrieves the input operands

From da88ecc49ee9bf10eba76d5753759b9701ad94c3 Mon Sep 17 00:00:00 2001
From: nouman-10x
Date: Wed, 12 Mar 2025 23:39:33 +0500
Subject: [PATCH 7/7] Added note about ValueRange

Signed-off-by: nouman-10x
---
 lib/Transform/Linalg/MatmulToMmt4d.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lib/Transform/Linalg/MatmulToMmt4d.cpp b/lib/Transform/Linalg/MatmulToMmt4d.cpp
index 1374947..16297c1 100644
--- a/lib/Transform/Linalg/MatmulToMmt4d.cpp
+++ b/lib/Transform/Linalg/MatmulToMmt4d.cpp
@@ -101,7 +101,8 @@ struct Matmul : public OpRewritePattern<linalg::MatmulOp> {
         loc, outputs[0], emptyOp2, resInnerDimsPos, resTileSizes,
         paddingValue, resInnerDimsPos);
 
-    // TODO: What is ValueRange?
+    // ValueRange is just a view over the underlying data
+    // It does not take ownership of the data
     linalg::Mmt4DOp mmt4d = rewriter.create<linalg::Mmt4DOp>(
         loc, resPack.getResult().getType(),
         ValueRange{lhsPack->getResult(0), rhsPack->getResult(0)},
        ValueRange{resPack->getResult(0)});
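
---

Usage note for the series: with the flags added in PATCH 6, the pass can be driven with per-dimension tile sizes, e.g.

    dummy-opt --matmul-to-mmt4d --dummy-m-tile=16 --dummy-n-tile=16 --dummy-k-tile=16 tests/matmul_to_mmt4d.mlir

For @matmul_f32 with the default 32x32x32 tiles, the rewrite is expected to produce IR roughly like the sketch below. This is a hand-written sketch, not captured pass output: the SSA names and the tensor.empty destinations (%lhs_dest, %rhs_dest, %acc_dest, %res_dest) are illustrative, and attribute printing may differ across MLIR versions.

    %cst = arith.constant 0.000000e+00 : f32
    // pack LHS as M x K x M0 x K0
    %lhs_pack = linalg.pack %lhs padding_value(%cst : f32)
        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
        into %lhs_dest : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
    // pack RHS as N x K x N0 x K0 (transposed)
    %rhs_pack = linalg.pack %rhs padding_value(%cst : f32)
        outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 32]
        into %rhs_dest : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
    // pack the accumulator as M x N x M0 x N0
    %acc_pack = linalg.pack %acc padding_value(%cst : f32)
        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
        into %acc_dest : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
    %mmt4d = linalg.mmt4d
        ins(%lhs_pack, %rhs_pack : tensor<?x?x32x32xf32>, tensor<?x?x32x32xf32>)
        outs(%acc_pack : tensor<?x?x32x32xf32>) -> tensor<?x?x32x32xf32>
    // unpack back to the original M x N shape
    %result = linalg.unpack %mmt4d
        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
        into %res_dest : tensor<?x?x32x32xf32> -> tensor<?x?xf32>

The RHS is packed with inner_dims_pos = [1, 0] and {N0, K0} tiles because linalg.mmt4d expects its right-hand operand pre-transposed into an N x K x N0 x K0 layout; that transposed layout is the "t" in mmt4d and is what distinguishes it from a plain tiled matmul.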