Skip to content

Commit

Permalink
Tile configuration addition pass
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Feb 26, 2024
1 parent 10674df commit b6137c5
Show file tree
Hide file tree
Showing 10 changed files with 299 additions and 8 deletions.
6 changes: 4 additions & 2 deletions include/TPP/Dialect/Xsmm/XsmmEnum.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def Xsmm_GemmFlags : I64EnumAttr<
I64EnumAttrCase<"BETA_0", 4, "beta_0">,
I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">,
I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">,
I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">
]> {
I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">,
I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">,
I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig">
]> {
let cppNamespace = "mlir::xsmm";
}
19 changes: 19 additions & 0 deletions include/TPP/Dialect/Xsmm/XsmmOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,15 @@ def Xsmm_GemmDispatchOp : Xsmm_GemmLikeOp<"gemm.dispatch"> {
def Xsmm_BrgemmDispatchOp : Xsmm_GemmLikeOp<"brgemm.dispatch"> {
let summary = "dispatch for brgemm operation.";
let hasVerifier = 1;

}

//===----------------------------------------------------------------------===//
// TileConfigDispatchOp
//===----------------------------------------------------------------------===//

def Xsmm_TileConfigDispatchOp : Xsmm_GemmLikeOp<"tileConfig.dispatch"> {
let summary = "dispatch for tileConfig operation.";
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -298,4 +307,14 @@ def Xsmm_FusedBrgemmDispatchOp : Xsmm_Op<"fused_brgemm.dispatch", [Pure]> {
let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// TileConfigOp
//===----------------------------------------------------------------------===//

def Xsmm_TileConfigOp : Xsmm_Op<"tileConfig", [MemoryEffects<[MemWrite, MemRead]>]> {
let summary = "invoke for tileConfig operation.";
let arguments = (ins I64:$dispatch, Variadic<AnyMemRef>:$inputs);
}

#endif // TPP_XSMM_OPS
11 changes: 11 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def FoldXsmmFlags : Pass<"fold-xsmm-flags", "func::FuncOp"> {
let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}


def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> {
let summary = "Tile parallel loops";
let options = [
Expand All @@ -483,4 +484,14 @@ def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> {
let dependentDialects = ["affine::AffineDialect", "scf::SCFDialect"];
}

def TileConfigInsertionPass : Pass<"tile-config-insertion-pass",
"func::FuncOp"> {
let summary = "Insert tile configuration xsmm calls";
let description = [{
Insert tile configuration xsmm calls and perform LICM on them.
}];

let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}

#endif // TPP_DIALECT_TPP_PASSES
34 changes: 32 additions & 2 deletions lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,19 @@ struct ConvertFusedBrgemmXsmmOp : public OpRewritePattern<FusedBrgemmOp> {
}
};

struct ConvertTileConfigXsmmOp : public OpRewritePattern<TileConfigOp> {
using OpRewritePattern<TileConfigOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TileConfigOp tileConfigOp,
PatternRewriter &rewriter) const override {
std::string funcName = "xsmm_tile_config_invoke";
buildInvokeCall(rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp,
xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16));
rewriter.eraseOp(tileConfigOp);
return success();
}
};

static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc,
ArrayRef<Value> dispatchOperands,
ArrayRef<Type> dispatchOperandTypes,
Expand Down Expand Up @@ -227,6 +240,13 @@ void addKindOperand(RewriterBase &rewriter, FusedBrgemmDispatchOp dispatchOp,
/* do nothing */
}

void addKindOperand(RewriterBase &rewriter, TileConfigDispatchOp dispatchOp,
SmallVectorImpl<Value> &dispatchOperands,
SmallVectorImpl<Type> &dispatchOperandTypes) {
/* do nothing */
}


static int64_t getOredFlags(ArrayAttr flags) {
int64_t oredFlag = 0;
for (auto flag : flags) {
Expand Down Expand Up @@ -370,6 +390,16 @@ struct ConvertUnaryDispatchOp : public OpRewritePattern<UnaryDispatchOp> {
}
};

struct ConvertTileConfigDispatchOp : public OpRewritePattern<TileConfigDispatchOp> {
using OpRewritePattern<TileConfigDispatchOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TileConfigDispatchOp dispatchOp,
PatternRewriter &rewriter) const override {
return buildDispatchOp<TileConfigDispatchOp>(rewriter, dispatchOp,
"xsmm_tile_config_dispatch");
}
};

struct ConvertFusedBrgemmOp : public OpRewritePattern<FusedBrgemmDispatchOp> {
using OpRewritePattern<FusedBrgemmDispatchOp>::OpRewritePattern;

Expand All @@ -395,11 +425,11 @@ struct ConvertXsmmToFunc
RewritePatternSet patterns(&getContext());
patterns
.add<ConvertBinaryXsmmOp, ConvertUnaryXsmmOp,
ConvertGemmXsmmOp, ConvertBrgemmXsmmOp, ConvertFusedBrgemmXsmmOp>(
ConvertGemmXsmmOp, ConvertBrgemmXsmmOp, ConvertFusedBrgemmXsmmOp, ConvertTileConfigXsmmOp>(
patterns.getContext());
patterns.add<ConvertBinaryDispatchOp,
ConvertUnaryDispatchOp, ConvertGemmDispatchOp,
ConvertBrgemmDispatchOp, ConvertFusedBrgemmOp>(
ConvertBrgemmDispatchOp, ConvertFusedBrgemmOp, ConvertTileConfigDispatchOp>(
patterns.getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/DefaultPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
tilingOptions.tileSizes = parallelTaskGrid;
pm.addPass(createSCFParallelLoopTiling(tilingOptions));
pm.addNestedPass<func::FuncOp>(createTileConfigInsertionPass());
pm.addPass(createConvertSCFToOpenMPPass());
}
pm.addPass(createConvertVectorToSCFPass());
Expand Down
15 changes: 15 additions & 0 deletions lib/TPP/Dialect/Xsmm/XsmmOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,21 @@ void BinaryDispatchOp::print(OpAsmPrinter &printer) {
printerDataTypeImpl<BinaryDispatchOp>(printer, *this);
}

void TileConfigDispatchOp::print(OpAsmPrinter &printer) {
printerInputImpl<TileConfigDispatchOp>(printer, *this);
auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); };
printerFlagsImpl<GemmFlagsAttr>(printer, getOpFlags, FLAGS_NAME);
printerDataTypeImpl<TileConfigDispatchOp>(printer, *this);
}

ParseResult TileConfigDispatchOp::parse(OpAsmParser &parser,
OperationState &result) {
if (failed(parseInputImpl(parser, result)) ||
failed(parserFlagsImpl<GemmFlags>(parser, result, FLAGS_NAME)))
return failure();
return parseDataTypeImpl(parser, result);
}

template <typename FLAGS>
static LogicalResult
verifyUniquenessAndConsistency(ArrayAttr flags, Operation *op,
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_mlir_library(TPPTransforms
TransformUtils.cpp
CombineXsmmPass.cpp
SCFParallelLoopTiling.cpp
TileConfig.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
154 changes: 154 additions & 0 deletions lib/TPP/Transforms/TileConfig.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop tiling on parallel loops.
//
//===----------------------------------------------------------------------===//
#include "TPP/Dialect/Xsmm/XsmmOps.h"
#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iostream>
#include <list>
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_TILECONFIGINSERTIONPASS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

using namespace mlir;

namespace mlir {
namespace tpp {

template <typename InvokeOpTy, typename DispatchOpTy>
static void appendBrgemmFlags(SmallVector<Attribute> &attributes,
PatternRewriter &rewriter, InvokeOpTy opTy) {
auto flags =
dyn_cast<DispatchOpTy>(opTy.getOperand(0).getDefiningOp()).getFlags();
for (auto flagItr : flags) {
if (flagItr == xsmm::GemmFlagsAttr::get(rewriter.getContext(),
xsmm::GemmFlags::NONE)) {
return;
}
attributes.push_back(flagItr);
}

if (attributes.empty())
attributes.push_back(
xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::NONE));
}

template <typename InvokeOpTy, typename DispatchOpTy>
struct TileConfig : OpRewritePattern<InvokeOpTy> {
using OpRewritePattern<InvokeOpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(InvokeOpTy op,
PatternRewriter &rewriter) const override {
if (xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()) !=
xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16))
return failure();
auto flags =
dyn_cast<DispatchOpTy>(op.getOperand(0).getDefiningOp()).getFlags();
for (auto flagItr : flags) {
if (flagItr == xsmm::GemmFlagsAttr::get(
rewriter.getContext(),
mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG) ||
flagItr == xsmm::GemmFlagsAttr::get(
rewriter.getContext(),
mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) {
return failure();
}
}

SmallVector<Attribute> attributesSetup;
attributesSetup.push_back(xsmm::GemmFlagsAttr::get(
rewriter.getContext(), xsmm::GemmFlags::NO_RESET_TILECONFIG));
appendBrgemmFlags<InvokeOpTy, DispatchOpTy>(attributesSetup, rewriter, op);
auto tileConfigSetup = rewriter.create<xsmm::TileConfigDispatchOp>(
op.getLoc(), rewriter.getI64Type(),
DenseI64ArrayAttr::get(
rewriter.getContext(),
dyn_cast<DispatchOpTy>(op.getOperand(0).getDefiningOp())
.getInputs()),
rewriter.getArrayAttr(attributesSetup),
xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()));

SmallVector<Attribute> attributesReset;
attributesReset.push_back(xsmm::GemmFlagsAttr::get(
rewriter.getContext(), xsmm::GemmFlags::NO_SETUP_TILECONFIG));
appendBrgemmFlags<InvokeOpTy, DispatchOpTy>(attributesReset, rewriter, op);
auto tileConfigReset = rewriter.create<xsmm::TileConfigDispatchOp>(
op.getLoc(), rewriter.getI64Type(),
DenseI64ArrayAttr::get(
rewriter.getContext(),
dyn_cast<DispatchOpTy>(op.getOperand(0).getDefiningOp())
.getInputs()),
rewriter.getArrayAttr(attributesReset),
xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()));

SmallVector<Attribute> attributesBrgemm;
attributesBrgemm.push_back(xsmm::GemmFlagsAttr::get(
rewriter.getContext(), xsmm::GemmFlags::NO_RESET_TILECONFIG));
attributesBrgemm.push_back(xsmm::GemmFlagsAttr::get(
rewriter.getContext(), xsmm::GemmFlags::NO_SETUP_TILECONFIG));
appendBrgemmFlags<InvokeOpTy, DispatchOpTy>(attributesBrgemm, rewriter, op);

auto dispatch = dyn_cast<DispatchOpTy>(
rewriter.clone(*op.getOperand(0).getDefiningOp()));
dispatch.setFlagsAttr(rewriter.getArrayAttr(attributesBrgemm));

auto alloca = rewriter.create<memref::AllocaOp>(
op.getLoc(), MemRefType::get({64}, rewriter.getI8Type()));

ValueRange tileConfigInputs{alloca};
rewriter.create<mlir::xsmm::TileConfigOp>(
op.getLoc(), tileConfigSetup, tileConfigInputs);

SmallVector<Value> invokeOperands;
invokeOperands.push_back(dispatch);
auto opItr = op->getOperands().begin();
std::advance(opItr, 1);
invokeOperands.append(opItr, op->getOperands().end());
rewriter.create<InvokeOpTy>(
op.getLoc(),
xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()),
invokeOperands);

ValueRange tileResetInputs{alloca};
rewriter.create<mlir::xsmm::TileConfigOp>(
op.getLoc(), tileConfigReset, tileResetInputs);

rewriter.eraseOp(op);
rewriter.eraseOp(op.getOperand(0).getDefiningOp());
return success();
}
};

struct TileConfigInsertionPass
: public impl::TileConfigInsertionPassBase<TileConfigInsertionPass> {
void populateCombinePatterns(RewritePatternSet &patterns) {
patterns.add<TileConfig<xsmm::BrgemmOp, xsmm::BrgemmDispatchOp>>(
patterns.getContext());
patterns.add<TileConfig<xsmm::FusedBrgemmOp, xsmm::FusedBrgemmDispatchOp>>(
patterns.getContext());
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateCombinePatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace tpp
} // namespace mlir
Loading

0 comments on commit b6137c5

Please sign in to comment.