From 5d90b11a2cac99af0c16313a9b78c7181a388559 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Wed, 5 Jun 2024 19:54:32 -0700 Subject: [PATCH] Loop tiling, shuffle and expansion passes --- include/TPP/PassBundles.td | 20 +- include/TPP/Passes.td | 39 +++ lib/TPP/DefaultPipeline.cpp | 31 +- lib/TPP/DefaultTppPasses.cpp | 33 ++- lib/TPP/PassBundles/LinalgLowering.cpp | 2 - .../PassBundles/LowLevelParallelization.cpp | 20 +- lib/TPP/Transforms/CMakeLists.txt | 3 + lib/TPP/Transforms/CombineXsmmPass.cpp | 20 +- lib/TPP/Transforms/LoopExpansion.cpp | 143 +++++++++ lib/TPP/Transforms/LoopInsertion.cpp | 276 ++++++++++++++++++ lib/TPP/Transforms/LoopShuffle.cpp | 128 ++++++++ .../pass-convert-gemm-to-parallel-tile.mlir | 18 +- .../pass-convert-mlp-to-parallel-tile.mlir | 24 +- 13 files changed, 715 insertions(+), 42 deletions(-) create mode 100644 lib/TPP/Transforms/LoopExpansion.cpp create mode 100644 lib/TPP/Transforms/LoopInsertion.cpp create mode 100644 lib/TPP/Transforms/LoopShuffle.cpp diff --git a/include/TPP/PassBundles.td b/include/TPP/PassBundles.td index 66cc514e4..688e0fff0 100644 --- a/include/TPP/PassBundles.td +++ b/include/TPP/PassBundles.td @@ -34,7 +34,15 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> { "bool", /*default=*/"false", "Skip all TPP transformations. Lower linalg directly to loops.">, ListOption<"parallelTaskGrid", "parallel-task-grid", - "unsigned", "Grid-sizes for parallel tasks."> + "unsigned", "Grid-sizes for parallel tasks.">, + ListOption<"tileShapeM", "M-tile-shape", "unsigned", + "Shape to reshape the M tensor into">, + ListOption<"tileShapeN", "N-tile-shape", "unsigned", + "Shape to reshape the N tensor into">, + ListOption<"shuffleOrder", "loop-shuffle-order", "unsigned", + "Shuffle order of scf for all loop surrounding brgemm op">, + Option<"outerParallelLoops", "num-outer-parallel", "unsigned", "0", + "Number of outer loops to be parallelized"> ]; } @@ -68,8 +76,14 @@ def LowLevelParallelization : Pass<"low-level-parallel", "ModuleOp"> { "xsmm::XsmmDialect", "LLVM::LLVMDialect"]; let options = [ - ListOption<"parallelTaskGrid", "parallel-task-grid", - "unsigned", "Grid-sizes for parallel tasks."> + ListOption<"tileShapeM", "M-tile-shape", "unsigned", + "Shape to reshape the M tensor into">, + ListOption<"tileShapeN", "N-tile-shape", "unsigned", + "Shape to reshape the N tensor into">, + ListOption<"shuffleOrder", "loop-shuffle-order", "unsigned", + "Shuffle order of scf for all loop surrounding brgemm op">, + Option<"outerParallelLoops", "num-outer-parallel", "unsigned","0", + "Number of outer loops to be parallelized"> ]; } diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 3f4de5935..69d0b73b4 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -429,6 +429,45 @@ def ConvertAddInplacePass: Pass<"linalg-convert-add-in-place", let dependentDialects = ["linalg::LinalgDialect"]; } +def LoopInsertionPass: Pass<"loop-insertion-pass">{ + let summary = "Insert loop around brgemm parallel op"; + let description = [{ + Insert nd parallel loop around brgemm parallel loop. + }]; + let dependentDialects = ["scf::SCFDialect" , "xsmm::XsmmDialect"]; + let options = [ + ListOption<"tileShapeM", "M-tile-shape", "unsigned", + "Shape to reshape the M tensor into">, + ListOption<"tileShapeN", "N-tile-shape", "unsigned", + "Shape to reshape the N tensor into"> + ]; +} + +def LoopExpansionPass: Pass<"loop-expansion-pass">{ + let summary = "Expand brgemm parallel op"; + let description = [{ + Expand nd parallel loop. + }]; + let options = [ + Option<"numOuterParallel", "num-outer-parallel", "unsigned", + "0", "Number of outer Parallel Loops"> + ]; + let dependentDialects = ["scf::SCFDialect"]; +} + +def LoopShufflePass: Pass<"loop-shuffle-pass">{ + let summary = "Shuffle brgemm parallel op"; + let description = [{ + Shuffle nd brgemm parallel loop. + }]; + let options = [ + ListOption<"shuffleOrder", "shuffle-order", "unsigned", + "Order to shuffle the parallel loop by"> + ]; + + let dependentDialects = ["scf::SCFDialect"]; +} + def TppRunnerWrapper : Pass<"tpp-runner-wrapper", "ModuleOp">{ let summary = "Create main function runner wrapper"; let description = [{ diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp index e7b29b4c0..542224804 100644 --- a/lib/TPP/DefaultPipeline.cpp +++ b/lib/TPP/DefaultPipeline.cpp @@ -53,6 +53,24 @@ llvm::cl::list llvm::cl::list_init(SmallVector{2, 8}), llvm::cl::CommaSeparated); +llvm::cl::list tileShapeM("M-tile-shape", + llvm::cl::desc("Tile shape of M tensor"), + llvm::cl::CommaSeparated); + +llvm::cl::list tileShapeN("N-tile-shape", + llvm::cl::desc("Tile shape of N tensor"), + llvm::cl::CommaSeparated); + +llvm::cl::list shuffleOrder( + "loop-shuffle-order", + llvm::cl::desc("shuffle order of scf for all loop surrounding brgemm op"), + llvm::cl::CommaSeparated); + +llvm::cl::opt outerParallelLoops( + "num-outer-parallel", + llvm::cl::desc("Number of outer loops to be parallelized"), + llvm::cl::value_desc("int"), llvm::cl::init(0)); + namespace mlir { namespace tpp { #define GEN_PASS_DEF_DEFAULTPIPELINE @@ -124,9 +142,16 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend})); } else { // Apply the default preprocessing pass - DefaultTppPassesOptions tppDefaultOptions{linalgToLoops, - parallelTaskGrid}; - pm.addPass(createDefaultTppPasses(tppDefaultOptions)); + if (!tileShapeM.empty() && !tileShapeN.empty()) { + DefaultTppPassesOptions tppDefaultOptions{ + linalgToLoops, parallelTaskGrid, tileShapeM, + tileShapeN, shuffleOrder, outerParallelLoops}; + pm.addPass(createDefaultTppPasses(tppDefaultOptions)); + } else { + DefaultTppPassesOptions tppDefaultOptions{linalgToLoops, + parallelTaskGrid}; + pm.addPass(createDefaultTppPasses(tppDefaultOptions)); + } } if (print == PrintStage::Mid) diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp index e768225ee..15ac8ef8f 100644 --- a/lib/TPP/DefaultTppPasses.cpp +++ b/lib/TPP/DefaultTppPasses.cpp @@ -14,6 +14,7 @@ #include "mlir/InitAllDialects.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" #include "TPP/Dialect/Check/BufferizableOpInterfaceImpl.h" #include "TPP/Dialect/Check/CheckDialect.h" @@ -101,13 +102,33 @@ struct DefaultTppPasses pm.addPass(createCleanup()); } - // Convert forAll to parallel loops should run after bufferization - // as scf.parallel does not handle tensor. - pm.addPass(createConvertForAllToParallelOp()); + // Low level parallelization passes. + if (!tileShapeM.empty() && !tileShapeN.empty()) { + LowLevelParallelizationOptions LowLevelParallelization( + LowLevelParallelizationOptions{tileShapeM, tileShapeN, shuffleOrder, + outerParallelLoops}); + pm.addPass(createLowLevelParallelization(LowLevelParallelization)); - // Low leve parallelization passes. - LowLevelParallelizationOptions LowLevelParallelization{parallelTaskGrid}; - pm.addPass(createLowLevelParallelization(LowLevelParallelization)); + // Convert forAll to parallel loops should run after bufferization + // as scf.parallel does not handle tensor. + pm.addPass(createConvertForAllToParallelOp()); + } else { + // FIXME remove as soon as the above code is fixed + pm.addPass(createConvertForAllToParallelOp()); + mlir::tpp::SCFParallelLoopTilingOptions tilingOptions; + tilingOptions.tileSizes = parallelTaskGrid; + pm.addPass(createSCFParallelLoopTiling(tilingOptions)); + + pm.addNestedPass(createIntelAMXTileConfigInsertionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createLoopInvariantCodeMotionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createIntelAMXTileConfigHoistingPass()); + pm.addPass(createCombineXsmmOpPass()); + pm.addNestedPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createFoldXsmmFlags()); + pm.addPass(createVerifyXsmmCalls()); + } // Covert all local TPP-related dialects. pm.addPass(createLocalDialectsLowering()); diff --git a/lib/TPP/PassBundles/LinalgLowering.cpp b/lib/TPP/PassBundles/LinalgLowering.cpp index 0b863c934..2d4647eff 100644 --- a/lib/TPP/PassBundles/LinalgLowering.cpp +++ b/lib/TPP/PassBundles/LinalgLowering.cpp @@ -48,8 +48,6 @@ struct LinalgLowering : public tpp::impl::LinalgLoweringBase, private: void constructPipeline() override { pm.addPass(createConvertLinalgToXsmm()); - pm.addPass(createCombineXsmmOpPass()); - pm.addPass(createFoldXsmmFlags()); pm.addPass(createVerifyXsmmCalls()); } }; diff --git a/lib/TPP/PassBundles/LowLevelParallelization.cpp b/lib/TPP/PassBundles/LowLevelParallelization.cpp index a396c6cd0..045c90fbc 100644 --- a/lib/TPP/PassBundles/LowLevelParallelization.cpp +++ b/lib/TPP/PassBundles/LowLevelParallelization.cpp @@ -62,14 +62,28 @@ struct LowLevelParallelization // that they are hoisted out of loops. pm.addPass(createCleanup()); - mlir::tpp::SCFParallelLoopTilingOptions tilingOptions; - tilingOptions.tileSizes = parallelTaskGrid; - pm.addPass(createSCFParallelLoopTiling(tilingOptions)); + mlir::tpp::LoopInsertionPassOptions loopInsertionPassOptions; + loopInsertionPassOptions.tileShapeM = tileShapeM; + loopInsertionPassOptions.tileShapeN = tileShapeN; + pm.addPass(createLoopInsertionPass(loopInsertionPassOptions)); + pm.addNestedPass(createCanonicalizerPass()); + + mlir::tpp::LoopShufflePassOptions loopShufflePassOptions; + loopShufflePassOptions.shuffleOrder = shuffleOrder; + pm.addPass(createLoopShufflePass(loopShufflePassOptions)); + + mlir::tpp::LoopExpansionPassOptions loopExpansionPassOptions; + loopExpansionPassOptions.numOuterParallel = outerParallelLoops; + pm.addPass(createLoopExpansionPass(loopExpansionPassOptions)); pm.addNestedPass(createIntelAMXTileConfigInsertionPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createLoopInvariantCodeMotionPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createIntelAMXTileConfigHoistingPass()); + pm.addPass(createCombineXsmmOpPass()); + pm.addNestedPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createFoldXsmmFlags()); + pm.addPass(createVerifyXsmmCalls()); } }; diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index fb7385b5c..7b57db486 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -21,6 +21,9 @@ add_mlir_library(TPPTransforms IntelAMXTileConfigHoisting.cpp LinalgConvertCompareSelectToMaximumfPass.cpp ConvertAddInplacePass.cpp + LoopInsertion.cpp + LoopExpansion.cpp + LoopShuffle.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP diff --git a/lib/TPP/Transforms/CombineXsmmPass.cpp b/lib/TPP/Transforms/CombineXsmmPass.cpp index 010db6ec0..6f1c8be99 100644 --- a/lib/TPP/Transforms/CombineXsmmPass.cpp +++ b/lib/TPP/Transforms/CombineXsmmPass.cpp @@ -126,19 +126,31 @@ struct CombineXsmmOp : public OpRewritePattern { // Replace and delete the old invokes and their dispatches rewriter.create(loc, dtype, invokeOperands); + assert(brgemmOp.use_empty()); rewriter.eraseOp(brgemmOp); - rewriter.eraseOp(brgemmOp.getOperand(0).getDefiningOp()); + if (brgemmOp.getOperand(0).getDefiningOp()->use_empty()) { + rewriter.eraseOp(brgemmOp.getOperand(0).getDefiningOp()); + } if (fusedMatch.binaryOp) { + assert(fusedMatch.binaryOp.use_empty()); rewriter.eraseOp(fusedMatch.binaryOp); - rewriter.eraseOp(fusedMatch.binaryOp->getOperand(0).getDefiningOp()); + if (fusedMatch.binaryOp->getOperand(0).getDefiningOp()->use_empty()) { + rewriter.eraseOp(fusedMatch.binaryOp->getOperand(0).getDefiningOp()); + } } if (fusedMatch.unaryOp) { + assert(fusedMatch.unaryOp.use_empty()); rewriter.eraseOp(fusedMatch.unaryOp); - rewriter.eraseOp(fusedMatch.unaryOp->getOperand(0).getDefiningOp()); + if (fusedMatch.unaryOp->getOperand(0).getDefiningOp()->use_empty()) { + rewriter.eraseOp(fusedMatch.unaryOp->getOperand(0).getDefiningOp()); + } } if (fusedMatch.zeroOp) { + assert(fusedMatch.zeroOp.use_empty()); rewriter.eraseOp(fusedMatch.zeroOp); - rewriter.eraseOp(fusedMatch.zeroOp->getOperand(0).getDefiningOp()); + if (fusedMatch.zeroOp->getOperand(0).getDefiningOp()->use_empty()) { + rewriter.eraseOp(fusedMatch.zeroOp->getOperand(0).getDefiningOp()); + } } return success(); } diff --git a/lib/TPP/Transforms/LoopExpansion.cpp b/lib/TPP/Transforms/LoopExpansion.cpp new file mode 100644 index 000000000..43ca9fd02 --- /dev/null +++ b/lib/TPP/Transforms/LoopExpansion.cpp @@ -0,0 +1,143 @@ +//===- LoopExpansion.cpp -----------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file splits parallel loop into scf fors. +// +//===----------------------------------------------------------------------===// +#include "TPP/Dialect/Xsmm/XsmmOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace tpp { +#define GEN_PASS_DECL_LOOPEXPANSIONPASS +#define GEN_PASS_DEF_LOOPEXPANSIONPASS +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +using namespace mlir; +using namespace mlir::scf; +using namespace std; + +namespace mlir { +namespace tpp { + +struct LoopExpansion : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LoopExpansion(MLIRContext *ctx, LoopExpansionPassOptions &options) + : OpRewritePattern(ctx), options(options){}; + + LogicalResult matchAndRewrite(scf::ForallOp op, + PatternRewriter &rewriter) const override { + if (options.numOuterParallel > op.getInductionVars().size()) + return failure(); + + xsmm::BrgemmOp brgemmOp = NULL; + for (auto oper = op.getBody()->getOperations().begin(); + oper != op.getBody()->getOperations().end(); oper++) + if (dyn_cast(oper)) { + brgemmOp = dyn_cast(oper); + break; + } + if (brgemmOp == NULL) + return failure(); + auto ub = op.getStaticUpperBound().begin(); + auto step = op.getStaticStep().begin(); + rewriter.setInsertionPointAfter(op); + + SmallVector parallelOpList; + SmallVector forOpList; + size_t i = 0; + for (auto lb = op.getStaticLowerBound().begin(); + lb != op.getStaticLowerBound().end() && + ub != op.getStaticUpperBound().end() && + step != op.getStaticStep().end(); + lb++, ub++, step++) { + auto lowerBound = + rewriter.create(op.getLoc(), *lb); + auto upperBound = + rewriter.create(op.getLoc(), *ub); + auto stepVal = + rewriter.create(op.getLoc(), *step); + + if (i++ < options.numOuterParallel) { + auto parallelOp = rewriter.create( + op.getLoc(), ValueRange{lowerBound}, ValueRange{upperBound}, + ValueRange{stepVal}); + rewriter.setInsertionPoint(¶llelOp.getBody()->front()); + parallelOpList.push_back(parallelOp); + } else { + auto forOp = rewriter.create(op.getLoc(), lowerBound, + upperBound, stepVal); + rewriter.setInsertionPoint(&forOp.getBody()->front()); + forOpList.push_back(forOp); + } + } + + IRMapping mapping; + for (auto oper = op.getBody()->getOperations().begin(); + oper != op.getBody()->getOperations().end(); oper++) { + if (!isa(oper)) { + auto clonedInstr = rewriter.clone(*oper, mapping); + oper->replaceAllUsesWith(clonedInstr); + int j = 0; + for (auto arg : clonedInstr->getOperands()) { + for (size_t i = 0; i < op.getInductionVars().size(); i++) { + if (arg == op.getInductionVars()[i]) { + if (i < options.numOuterParallel) { + clonedInstr->setOperand( + j, parallelOpList[i].getInductionVars()[0]); + } else { + clonedInstr->setOperand( + j, + forOpList[i - options.numOuterParallel].getInductionVar()); + } + break; + } + } + j++; + } + } + } + rewriter.eraseOp(op); + return success(); + } + +private: + LoopExpansionPassOptions options; +}; + +struct LoopExpansionPass + : public impl::LoopExpansionPassBase { + + LoopExpansionPass() {} + + LoopExpansionPass(const LoopExpansionPassOptions &options) { + this->numOuterParallel = options.numOuterParallel; + } + + void populateCombinePatterns(RewritePatternSet &patterns, + LoopExpansionPassOptions options) { + patterns.add(patterns.getContext(), options); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateCombinePatterns(patterns, + LoopExpansionPassOptions{numOuterParallel}); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; +} // namespace tpp +} // namespace mlir diff --git a/lib/TPP/Transforms/LoopInsertion.cpp b/lib/TPP/Transforms/LoopInsertion.cpp new file mode 100644 index 000000000..1883178b6 --- /dev/null +++ b/lib/TPP/Transforms/LoopInsertion.cpp @@ -0,0 +1,276 @@ +//===- LoopInsertion.cpp -----------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements parallel loop insertion for tiling. +// +//===----------------------------------------------------------------------===// +#include "TPP/Dialect/Xsmm/XsmmOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tpp { +#define GEN_PASS_DECL_LOOPINSERTIONPASS +#define GEN_PASS_DEF_LOOPINSERTIONPASS +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +using namespace mlir; +using namespace mlir::scf; + +namespace mlir { +namespace tpp { + +static SmallVector +getReassociationIndices(ArrayRef origtensorShape, + SmallVector> tileShapes) { + SmallVector indices; + + size_t index = 0; + for (size_t i = 0; i < tileShapes.size(); i++) { + ReassociationIndices reassociationIndex; + for (size_t j = 0; j < tileShapes[i].size(); j++) + reassociationIndex.push_back(index++); + indices.push_back(reassociationIndex); + } + for (size_t i = tileShapes.size(); i < origtensorShape.size(); i++) { + ReassociationIndices reassociationIndex; + reassociationIndex.push_back(index++); + indices.push_back(reassociationIndex); + } + + return indices; +} + +void insertSubview(ArrayRef tensorShape, Type type, Type resultType, + SmallVector reassociation, + Value operand, ForallOp op, OpBuilder b, + xsmm::BrgemmOp brgemmOp, int operandNumber) { + + auto expandShape = b.create( + op.getLoc(), + MemRefType::get({tensorShape}, + dyn_cast(type).getElementType()), + operand, reassociation); + expandShape.setStaticOutputShape(tensorShape); + b.setInsertionPoint(brgemmOp); + SmallVector strides(tensorShape.size(), b.getIndexAttr(1)), + sizes, offsets; + size_t tileSize = + tensorShape.size() - dyn_cast(resultType).getShape().size(); + + SmallVector tileSizes; + for (size_t i = 0; i < tensorShape.size(); i++) { + if (i < tileSize) { + int opnum = operandNumber; + if (opnum == 3) { + opnum = 1; + } + int inductionVarIndex = (opnum - 1) * tileSize + i; + offsets.push_back(op.getInductionVars()[inductionVarIndex]); + sizes.push_back(b.getIndexAttr(1)); + } else { + sizes.push_back(b.getIndexAttr(tensorShape[i])); + tileSizes.push_back(tensorShape[i]); + offsets.push_back(b.getIndexAttr(0)); + } + } + + auto subviewType = + MemRefType::get({tileSizes}, dyn_cast(type).getElementType()); + auto [originalStride, originalOffset] = + getStridesAndOffset(dyn_cast(subviewType)); + subviewType = MemRefType::get( + {tileSizes}, dyn_cast(subviewType).getElementType(), + StridedLayoutAttr::get(b.getContext(), ShapedType::kDynamic, + originalStride)); + auto subview = b.create( + op.getLoc(), dyn_cast(subviewType), expandShape.getResult(), + offsets, sizes, strides); + brgemmOp.getOperand(operandNumber).replaceAllUsesWith(subview); +} + +FailureOr insertParallelLoop(ForallOp op, + ArrayRef tileShapeM, + ArrayRef tileShapeN) { + xsmm::BrgemmOp brgemmOp = NULL; + OpBuilder b(op); + for (auto oper = op.getBody()->getOperations().begin(); + oper != op.getBody()->getOperations().end(); oper++) + if (dyn_cast(oper)) { + brgemmOp = dyn_cast(oper); + break; + } + if (brgemmOp == NULL) + return failure(); + + int boundSize = tileShapeM.size() + tileShapeN.size(); + auto mShape = + dyn_cast( + brgemmOp.getOperand(1).getDefiningOp()->getOperand(0).getType()) + .getShape(); + + // Validate the input tile sizes against the operand shapes + long multipleM = 1; + for (size_t i = 0; i < tileShapeM.size(); i++) + multipleM = multipleM * tileShapeM[i]; + + if (mShape[0] != multipleM) + return failure(); + + auto nShape = + dyn_cast( + brgemmOp.getOperand(2).getDefiningOp()->getOperand(0).getType()) + .getShape(); + + long multipleN = 1; + for (size_t i = 0; i < tileShapeN.size(); i++) + multipleN = multipleN * tileShapeN[i]; + + if (nShape[0] != multipleN) + return failure(); + + auto kShape = + dyn_cast( + brgemmOp.getOperand(3).getDefiningOp()->getOperand(0).getType()) + .getShape(); + + if ((multipleM * multipleN) != (kShape[0] * kShape[1])) + return failure(); + + // Set the new bounds of for loop + SmallVector lbs(boundSize, 0), steps(boundSize, 1); + + SmallVector ubs(tileShapeM.begin(), tileShapeM.end()); + ubs.append(tileShapeN.begin(), tileShapeN.end()); + + op.setStaticLowerBound(lbs); + op.setStaticUpperBound(ubs); + op.setStaticStep(steps); + + // Add new induction var args to the for loop + int numArgs = op.getBody()->getArguments().size(); + + for (int i = 0; i < boundSize - numArgs; i++) + op.getBody()->addArgument(b.getIndexType(), op.getLoc()); + + SmallVector tileOffsets{ + 0, static_cast(tileShapeM.size() - 1), + static_cast(tileShapeN.size() + tileShapeM.size() - 1)}; + b.setInsertionPoint(&op.getBody()->front()); + // Replace old args with newly computed args + for (auto oper = op.getBody()->getOperations().begin(); + oper != op.getBody()->getOperations().end(); oper++) { + int operandIndex = 0; + for (auto arg : oper->getOperands()) { + int oldArgIndex = -1; + for (int i = 0; i < numArgs; i++) { + if (arg == op.getBody()->getArgument(i)) { + oldArgIndex = i; + break; + } + } + if (oldArgIndex != -1) { + Value mul, add = NULL; + for (int j = tileOffsets[oldArgIndex + 1]; j > tileOffsets[oldArgIndex]; + j--) { + Value index = b.create( + op.getLoc(), op.getStaticUpperBound()[j - 1]); + mul = b.create(op.getLoc(), b.getIndexType(), + op.getBody()->getArgument(j), index); + add = b.create(op.getLoc(), b.getIndexType(), mul, + op.getBody()->getArgument(j - 1)); + } + assert(add != NULL); + oper->setOperand(operandIndex, add); + } + operandIndex++; + } + } + + SmallVector> originalShapes{mShape, nShape, kShape}; + SmallVector>> tilingVectors{ + {tileShapeM}, {tileShapeN}, {tileShapeM, tileShapeN}}; + + for (int i = 1; i <= 3; i++) { + auto operand = brgemmOp.getOperand(i).getDefiningOp()->getOperand(0); + auto operandType = operand.getType(); + auto resultType = dyn_cast(brgemmOp.getOperand(i).getType()); + auto reassociationIndex = + getReassociationIndices(originalShapes[i - 1], tilingVectors[i - 1]); + + SmallVector shape; + for (size_t j = 0; j < tilingVectors[i - 1].size(); j++) { + shape.append(tilingVectors[i - 1][j].begin(), + tilingVectors[i - 1][j].end()); + } + shape.append( + std::next(originalShapes[i - 1].begin(), tilingVectors[i - 1].size()), + originalShapes[i - 1].end()); + insertSubview(shape, operandType, resultType, reassociationIndex, operand, + op, b, brgemmOp, i); + } + + return op; +} + +bool getInnermostForLoops(Operation *rootOp, + SmallVectorImpl &result) { + assert(rootOp != nullptr && "Root operation must not be a nullptr."); + bool rootEnclosesForAllloops = false; + for (Region ®ion : rootOp->getRegions()) { + for (Block &block : region.getBlocks()) { + for (Operation &op : block) { + bool enclosesPloops = getInnermostForLoops(&op, result); + rootEnclosesForAllloops |= enclosesPloops; + if (auto ploop = dyn_cast(op)) { + rootEnclosesForAllloops = true; + + // Collect forall loop if it is an innermost one. + if (!enclosesPloops) + result.push_back(ploop); + } + } + } + } + return rootEnclosesForAllloops; +} + +struct LoopInsertionPass + : public tpp::impl::LoopInsertionPassBase { + + LoopInsertionPass(){}; + + LoopInsertionPass(ArrayRef tileShapeM, + ArrayRef tileShapeN) { + this->tileShapeM = tileShapeM; + this->tileShapeN = tileShapeN; + }; + + LoopInsertionPass(const tpp::LoopInsertionPassOptions &options) { + tileShapeM = options.tileShapeM; + tileShapeN = options.tileShapeN; + }; + + void runOnOperation() override { + auto *parentOp = getOperation(); + SmallVector innermostForAllloops; + getInnermostForLoops(parentOp, innermostForAllloops); + for (ForallOp loop : innermostForAllloops) { + if (failed(insertParallelLoop(loop, tileShapeM, tileShapeN))) { + return; + } + } + } +}; +} // namespace tpp +} // namespace mlir diff --git a/lib/TPP/Transforms/LoopShuffle.cpp b/lib/TPP/Transforms/LoopShuffle.cpp new file mode 100644 index 000000000..262c12d3e --- /dev/null +++ b/lib/TPP/Transforms/LoopShuffle.cpp @@ -0,0 +1,128 @@ +//===- LoopShuffle.cpp -----------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file shuffles parallel loop based on user input +// +//===----------------------------------------------------------------------===// +#include "TPP/Dialect/Xsmm/XsmmOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include + +namespace mlir { +namespace tpp { +#define GEN_PASS_DECL_LOOPSHUFFLEPASS +#define GEN_PASS_DEF_LOOPSHUFFLEPASS +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +using namespace mlir; +using namespace mlir::scf; +using namespace std; + +namespace mlir { +namespace tpp { + +struct LoopShuffle : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LoopShuffle(MLIRContext *ctx, LoopShufflePassOptions &options) + : OpRewritePattern(ctx), options(options) {} + + LogicalResult matchAndRewrite(scf::ForallOp op, + PatternRewriter &rewriter) const override { + + if (options.shuffleOrder.size() != op.getInductionVars().size()) + return failure(); + for (size_t i = 0; i < op.getInductionVars().size(); i++) { + bool match = false; + for (size_t j = 0; j < options.shuffleOrder.size(); j++) + if (i == options.shuffleOrder[j]) { + match = true; + break; + } + if (!match) { + return failure(); + } + } + xsmm::BrgemmOp brgemmOp = NULL; + static list visitedForallOp; + if (std::find(visitedForallOp.begin(), visitedForallOp.end(), op) != + visitedForallOp.end()) + return failure(); + for (auto oper = op.getBody()->getOperations().begin(); + oper != op.getBody()->getOperations().end(); oper++) + if (dyn_cast(oper)) { + brgemmOp = dyn_cast(oper); + break; + } + if (brgemmOp == NULL) + return failure(); + SmallVector lbs, ubs, steps; + + for (size_t i = 0; i < op.getStaticLowerBound().size(); i++) { + lbs.push_back(op.getStaticLowerBound()[options.shuffleOrder[i]]); + } + for (size_t i = 0; i < op.getStaticUpperBound().size(); i++) { + ubs.push_back(op.getStaticUpperBound()[options.shuffleOrder[i]]); + } + for (size_t i = 0; i < op.getStaticStep().size(); i++) { + steps.push_back(op.getStaticStep()[options.shuffleOrder[i]]); + } + + op.setStaticLowerBound(lbs); + op.setStaticUpperBound(ubs); + op.setStaticStep(steps); + SmallVector tempVector; + for (size_t i = 0; i < op.getInductionVars().size(); i++) { + auto tempValue = rewriter.create(op.getLoc(), i); + replaceAllUsesInRegionWith(op.getInductionVar(i), tempValue, + op.getRegion()); + tempVector.push_back(tempValue); + } + for (size_t i = 0; i < op.getInductionVars().size(); i++) { + replaceAllUsesInRegionWith(tempVector[i], + op.getInductionVar(options.shuffleOrder[i]), + op.getRegion()); + } + visitedForallOp.push_back(op); + return success(); + } + +private: + LoopShufflePassOptions options; +}; + +struct LoopShufflePass : public impl::LoopShufflePassBase { + + LoopShufflePass() {} + + LoopShufflePass(const LoopShufflePassOptions &options) { + this->shuffleOrder = options.shuffleOrder; + } + + void populateCombinePatterns(RewritePatternSet &patterns, + LoopShufflePassOptions options) { + patterns.add(patterns.getContext(), options); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateCombinePatterns(patterns, + LoopShufflePassOptions{this->shuffleOrder}); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; +} // namespace tpp +} // namespace mlir diff --git a/test/Passes/pass-convert-gemm-to-parallel-tile.mlir b/test/Passes/pass-convert-gemm-to-parallel-tile.mlir index f9e132da4..0bdef8c37 100644 --- a/test/Passes/pass-convert-gemm-to-parallel-tile.mlir +++ b/test/Passes/pass-convert-gemm-to-parallel-tile.mlir @@ -17,15 +17,15 @@ module { } // CHECK: func.func @_entry(%[[ARG0:.*]]: memref<8x32x32x32xf32>, %[[ARG1:.*]]: memref<32x32x32x32xf32>, %[[ARG2:.*]]: memref<8x32x32x32xf32>) { -// CHECK: %[[c2:.*]] = arith.constant 2 : index -// CHECK: %[[c32_i64:.*]] = arith.constant 32 : i64 -// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: %[[c8:.*]] = arith.constant 8 : index -// CHECK: %[[c32:.*]] = arith.constant 32 : index -// CHECK: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[c1_i64:.*]] = arith.constant 1 : i64 -// CHECK: %[[c1024_i64:.*]] = arith.constant 1024 : i64 -// CHECK: %[[c0_i64:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[c1024_i64:.*]] = arith.constant 1024 : i64 +// CHECK-DAG: %[[c0_i64:.*]] = arith.constant 0 : i64 // CHECK: %[[temp0:.*]] = call @xsmm_brgemm_dispatch(%[[c1_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c1024_i64]], %[[c1024_i64]], %[[c0_i64]]) // CHECK: omp.parallel { // CHECK: omp.wsloop { diff --git a/test/Passes/pass-convert-mlp-to-parallel-tile.mlir b/test/Passes/pass-convert-mlp-to-parallel-tile.mlir index 09567de45..b4462fcdb 100644 --- a/test/Passes/pass-convert-mlp-to-parallel-tile.mlir +++ b/test/Passes/pass-convert-mlp-to-parallel-tile.mlir @@ -65,18 +65,18 @@ module { //CHECK: func.func @_entry(%[[ARG0:.*]]: memref<8x32x32x32xf32>, %[[ARG1:.*]]: memref<32x32x32x32xf32>, %[[ARG2:.*]]: memref<32x32xf32>, %[[ARG3:.*]]: memref<8x32x32x32xf32>, %[[ARG4:.*]]: memref<32x32x32x32xf32>, %[[ARG5:.*]]: memref<32x32xf32>, %[[ARG6:.*]]: memref<8x32x32x32xf32>, %[[ARG7:.*]]: memref<32x32x32x32xf32>, %[[ARG8:.*]]: memref<32x32xf32>, %[[ARG9:.*]]: memref<8x32x32x32xf32>) { -//CHECK: %[[c16:.*]] = arith.constant 16 : index -//CHECK: %[[c2:.*]] = arith.constant 2 : index -//CHECK: %[[c32_i64:.*]] = arith.constant 32 : i64 -//CHECK: %[[c0:.*]] = arith.constant 0 : index -//CHECK: %[[c8:.*]] = arith.constant 8 : index -//CHECK: %[[c32:.*]] = arith.constant 32 : index -//CHECK: %[[c1:.*]] = arith.constant 1 : index -//CHECK: %[[c1_i64:.*]] = arith.constant 1 : i64 -//CHECK: %[[c1024_i64:.*]] = arith.constant 1024 : i64 -//CHECK: %[[c0_i64:.*]] = arith.constant 0 : i64 -//CHECK: %[[c5_i64:.*]] = arith.constant 5 : i64 -//CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64 +//CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index +//CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index +//CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64 +//CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +//CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index +//CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index +//CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +//CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +//CHECK-DAG: %[[c1024_i64:.*]] = arith.constant 1024 : i64 +//CHECK-DAG: %[[c0_i64:.*]] = arith.constant 0 : i64 +//CHECK-DAG: %[[c5_i64:.*]] = arith.constant 5 : i64 +//CHECK-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 //CHECK: %[[temp0:.*]] = call @xsmm_fused_brgemm_dispatch(%[[c1_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c1024_i64]], %[[c1024_i64]], %[[c0_i64]], %[[c0_i64]], %[[c5_i64]], %[[c4_i64]], %[[c1_i64]]) //CHECK: omp.parallel { //CHECK: omp.wsloop {