Skip to content

Commit

Permalink
Loop tiling, shuffle and expansion passes
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Jun 12, 2024
1 parent f307b5a commit 4699e3e
Show file tree
Hide file tree
Showing 13 changed files with 717 additions and 42 deletions.
20 changes: 17 additions & 3 deletions include/TPP/PassBundles.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> {
"bool", /*default=*/"false",
"Skip all TPP transformations. Lower linalg directly to loops.">,
ListOption<"parallelTaskGrid", "parallel-task-grid",
"unsigned", "Grid-sizes for parallel tasks.">
"unsigned", "Grid-sizes for parallel tasks.">,
ListOption<"tileShapeM", "M-tile-shape", "unsigned",
"Shape to reshape the M tensor into">,
ListOption<"tileShapeN", "N-tile-shape", "unsigned",
"Shape to reshape the N tensor into">,
ListOption<"shuffleOrder", "loop-shuffle-order", "unsigned",
"Shuffle order of scf.forall loops surrounding brgemm op">,
Option<"outerParallelLoops", "num-outer-parallel", "unsigned", "0",
"Number of outer loops to be parallelized">

];
}
Expand Down Expand Up @@ -68,8 +76,14 @@ def LowLevelParallelization : Pass<"low-level-parallel", "ModuleOp"> {
"xsmm::XsmmDialect",
"LLVM::LLVMDialect"];
let options = [
ListOption<"parallelTaskGrid", "parallel-task-grid",
"unsigned", "Grid-sizes for parallel tasks.">
ListOption<"tileShapeM", "M-tile-shape", "unsigned",
"Shape to reshape the M tensor into">,
ListOption<"tileShapeN", "N-tile-shape", "unsigned",
"Shape to reshape the N tensor into">,
ListOption<"shuffleOrder", "loop-shuffle-order", "unsigned",
"Shuffle order of scf.forall loops surrounding brgemm op">,
Option<"outerParallelLoops", "num-outer-parallel", "unsigned","0",
"Number of outer loops to be parallelized">

];
}
Expand Down
39 changes: 39 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,45 @@ def ConvertAddInplacePass: Pass<"linalg-convert-add-in-place",
let dependentDialects = ["linalg::LinalgDialect"];
}

def LoopInsertionPass: Pass<"loop-insertion-pass">{
let summary = "Insert loop around brgemm parallel op";
let description = [{
Insert nd parallel loop around brgemm parallel loop.
}];
let dependentDialects = ["scf::SCFDialect" , "xsmm::XsmmDialect"];
let options = [
ListOption<"tileShapeM", "M-tile-shape", "unsigned",
"Shape to reshape the M tensor into">,
ListOption<"tileShapeN", "N-tile-shape", "unsigned",
"Shape to reshape the N tensor into">
];
}

def LoopExpansionPass: Pass<"loop-expansion-pass">{
let summary = "Expand brgemm parallel op";
let description = [{
Expand nd parallel loop.
}];
let options = [
Option<"numOuterParallel", "num-outer-parallel", "unsigned",
"0", "Number of outer Parallel Loops">
];
let dependentDialects = ["scf::SCFDialect"];
}

def LoopShufflePass: Pass<"loop-shuffle-pass">{
let summary = "Shuffle brgemm parallel op";
let description = [{
Shuffle nd brgemm parallel loop.
}];
let options = [
ListOption<"shuffleOrder", "shuffle-order", "unsigned",
"Order to shuffle the parallel loop by">
];

let dependentDialects = ["scf::SCFDialect"];
}

def TppRunnerWrapper : Pass<"tpp-runner-wrapper", "ModuleOp">{
let summary = "Create main function runner wrapper";
let description = [{
Expand Down
31 changes: 28 additions & 3 deletions lib/TPP/DefaultPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ llvm::cl::list<unsigned>
llvm::cl::list_init<unsigned>(SmallVector<unsigned>{2, 8}),
llvm::cl::CommaSeparated);

llvm::cl::list<unsigned> tileShapeM("M-tile-shape",
llvm::cl::desc("Tile shape of M tensor"),
llvm::cl::CommaSeparated);

llvm::cl::list<unsigned> tileShapeN("N-tile-shape",
llvm::cl::desc("Tile shape of N tensor"),
llvm::cl::CommaSeparated);

llvm::cl::list<unsigned> shuffleOrder(
"loop-shuffle-order",
llvm::cl::desc("Shuffle order of scf.forall loops surrounding brgemm op"),
llvm::cl::CommaSeparated);

llvm::cl::opt<unsigned> outerParallelLoops(
"num-outer-parallel",
llvm::cl::desc("Number of outer loops to be parallelized"),
llvm::cl::value_desc("int"), llvm::cl::init(0));

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_DEFAULTPIPELINE
Expand Down Expand Up @@ -124,9 +142,16 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
} else {
// Apply the default preprocessing pass
DefaultTppPassesOptions tppDefaultOptions{linalgToLoops,
parallelTaskGrid};
pm.addPass(createDefaultTppPasses(tppDefaultOptions));
if (!tileShapeM.empty() && !tileShapeN.empty()) {
DefaultTppPassesOptions tppDefaultOptions{
linalgToLoops, parallelTaskGrid, tileShapeM,
tileShapeN, shuffleOrder, outerParallelLoops};
pm.addPass(createDefaultTppPasses(tppDefaultOptions));
} else {
DefaultTppPassesOptions tppDefaultOptions{linalgToLoops,
parallelTaskGrid};
pm.addPass(createDefaultTppPasses(tppDefaultOptions));
}
}

if (print == PrintStage::Mid)
Expand Down
33 changes: 27 additions & 6 deletions lib/TPP/DefaultTppPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "TPP/Dialect/Check/BufferizableOpInterfaceImpl.h"
#include "TPP/Dialect/Check/CheckDialect.h"
Expand Down Expand Up @@ -101,13 +102,33 @@ struct DefaultTppPasses
pm.addPass(createCleanup());
}

// Convert forAll to parallel loops should run after bufferization
// as scf.parallel does not handle tensor.
pm.addPass(createConvertForAllToParallelOp());
// Low level parallelization passes.
if (!tileShapeM.empty() && !tileShapeN.empty()) {
LowLevelParallelizationOptions LowLevelParallelization(
LowLevelParallelizationOptions{tileShapeM, tileShapeN, shuffleOrder,
outerParallelLoops});
pm.addPass(createLowLevelParallelization(LowLevelParallelization));

// Low level parallelization passes.
LowLevelParallelizationOptions LowLevelParallelization{parallelTaskGrid};
pm.addPass(createLowLevelParallelization(LowLevelParallelization));
// Convert forAll to parallel loops should run after bufferization
// as scf.parallel does not handle tensor.
pm.addPass(createConvertForAllToParallelOp());
} else {
// FIXME remove as soon as the above code is fixed
pm.addPass(createConvertForAllToParallelOp());
mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
tilingOptions.tileSizes = parallelTaskGrid;
pm.addPass(createSCFParallelLoopTiling(tilingOptions));

pm.addNestedPass<func::FuncOp>(createIntelAMXTileConfigInsertionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createIntelAMXTileConfigHoistingPass());
pm.addPass(createCombineXsmmOpPass());
pm.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
pm.addPass(createFoldXsmmFlags());
pm.addPass(createVerifyXsmmCalls());
}

// Covert all local TPP-related dialects.
pm.addPass(createLocalDialectsLowering());
Expand Down
2 changes: 0 additions & 2 deletions lib/TPP/PassBundles/LinalgLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ struct LinalgLowering : public tpp::impl::LinalgLoweringBase<LinalgLowering>,
private:
void constructPipeline() override {
pm.addPass(createConvertLinalgToXsmm());
pm.addPass(createCombineXsmmOpPass());
pm.addPass(createFoldXsmmFlags());
pm.addPass(createVerifyXsmmCalls());
}
};
20 changes: 17 additions & 3 deletions lib/TPP/PassBundles/LowLevelParallelization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,28 @@ struct LowLevelParallelization
// that they are hoisted out of loops.
pm.addPass(createCleanup());

mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
tilingOptions.tileSizes = parallelTaskGrid;
pm.addPass(createSCFParallelLoopTiling(tilingOptions));
mlir::tpp::LoopInsertionPassOptions loopInsertionPassOptions;
loopInsertionPassOptions.tileShapeM = tileShapeM;
loopInsertionPassOptions.tileShapeN = tileShapeN;
pm.addPass(createLoopInsertionPass(loopInsertionPassOptions));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

mlir::tpp::LoopShufflePassOptions loopShufflePassOptions;
loopShufflePassOptions.shuffleOrder = shuffleOrder;
pm.addPass(createLoopShufflePass(loopShufflePassOptions));

mlir::tpp::LoopExpansionPassOptions loopExpansionPassOptions;
loopExpansionPassOptions.numOuterParallel = outerParallelLoops;
pm.addPass(createLoopExpansionPass(loopExpansionPassOptions));

pm.addNestedPass<func::FuncOp>(createIntelAMXTileConfigInsertionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createIntelAMXTileConfigHoistingPass());
pm.addPass(createCombineXsmmOpPass());
pm.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
pm.addPass(createFoldXsmmFlags());
pm.addPass(createVerifyXsmmCalls());
}
};
3 changes: 3 additions & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ add_mlir_library(TPPTransforms
IntelAMXTileConfigHoisting.cpp
LinalgConvertCompareSelectToMaximumfPass.cpp
ConvertAddInplacePass.cpp
LoopInsertion.cpp
LoopExpansion.cpp
LoopShuffle.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
20 changes: 16 additions & 4 deletions lib/TPP/Transforms/CombineXsmmPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,31 @@ struct CombineXsmmOp : public OpRewritePattern<xsmm::BrgemmOp> {

// Replace and delete the old invokes and their dispatches
rewriter.create<xsmm::FusedBrgemmOp>(loc, dtype, invokeOperands);
assert(brgemmOp.use_empty());
rewriter.eraseOp(brgemmOp);
rewriter.eraseOp(brgemmOp.getOperand(0).getDefiningOp());
if (brgemmOp.getOperand(0).getDefiningOp()->use_empty()) {
rewriter.eraseOp(brgemmOp.getOperand(0).getDefiningOp());
}
if (fusedMatch.binaryOp) {
assert(fusedMatch.binaryOp.use_empty());
rewriter.eraseOp(fusedMatch.binaryOp);
rewriter.eraseOp(fusedMatch.binaryOp->getOperand(0).getDefiningOp());
if (fusedMatch.binaryOp->getOperand(0).getDefiningOp()->use_empty()) {
rewriter.eraseOp(fusedMatch.binaryOp->getOperand(0).getDefiningOp());
}
}
if (fusedMatch.unaryOp) {
assert(fusedMatch.unaryOp.use_empty());
rewriter.eraseOp(fusedMatch.unaryOp);
rewriter.eraseOp(fusedMatch.unaryOp->getOperand(0).getDefiningOp());
if (fusedMatch.unaryOp->getOperand(0).getDefiningOp()->use_empty()) {
rewriter.eraseOp(fusedMatch.unaryOp->getOperand(0).getDefiningOp());
}
}
if (fusedMatch.zeroOp) {
assert(fusedMatch.zeroOp.use_empty());
rewriter.eraseOp(fusedMatch.zeroOp);
rewriter.eraseOp(fusedMatch.zeroOp->getOperand(0).getDefiningOp());
if (fusedMatch.zeroOp->getOperand(0).getDefiningOp()->use_empty()) {
rewriter.eraseOp(fusedMatch.zeroOp->getOperand(0).getDefiningOp());
}
}
return success();
}
Expand Down
Loading

0 comments on commit 4699e3e

Please sign in to comment.