Skip to content

Commit

Permalink
Tile config changed to intel amx tile config
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Feb 28, 2024
1 parent 951f7e5 commit 96d146d
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 92 deletions.
12 changes: 6 additions & 6 deletions include/TPP/Dialect/Xsmm/XsmmOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,11 @@ def Xsmm_BrgemmDispatchOp : Xsmm_GemmLikeOp<"brgemm.dispatch"> {
}

//===----------------------------------------------------------------------===//
// TileConfigDispatchOp
// IntelAMXTileConfigDispatchOp
//===----------------------------------------------------------------------===//

def Xsmm_TileConfigDispatchOp : Xsmm_GemmLikeOp<"tileConfig.dispatch"> {
let summary = "dispatch for tileConfig operation.";
def Xsmm_IntelAMXTileConfigDispatchOp : Xsmm_GemmLikeOp<"IntelAMXtileConfig.dispatch"> {
let summary = "dispatch for Intel amx tileConfig operation.";
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -309,11 +309,11 @@ def Xsmm_FusedBrgemmDispatchOp : Xsmm_Op<"fused_brgemm.dispatch", [Pure]> {


//===----------------------------------------------------------------------===//
// TileConfigOp
// IntelAMXTileConfigOp
//===----------------------------------------------------------------------===//

def Xsmm_TileConfigOp : Xsmm_Op<"tileConfig", [MemoryEffects<[MemWrite, MemRead]>]> {
let summary = "invoke for tileConfig operation.";
def Xsmm_IntelAMXTileConfigOp : Xsmm_Op<"IntelAMXtileConfig", [MemoryEffects<[MemWrite, MemRead]>]> {
let summary = "invoke for Intel AMX tileConfig operation.";
let arguments = (ins I64:$dispatch, Variadic<AnyMemRef>:$inputs);
}

Expand Down
12 changes: 6 additions & 6 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -484,21 +484,21 @@ def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> {
let dependentDialects = ["affine::AffineDialect", "scf::SCFDialect"];
}

def TileConfigInsertionPass : Pass<"tile-config-insertion-pass",
def IntelAMXTileConfigInsertionPass : Pass<"intel-amx-tile-config-insertion-pass",
"func::FuncOp"> {
let summary = "Insert tile configuration xsmm calls";
let summary = "Insert intel amx tile configuration xsmm calls";
let description = [{
Insert tile configuration xsmm calls.
Insert intel amx tile configuration xsmm calls.
}];

let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}

def TileConfigHoistingPass : Pass<"tile-config-hoisting-pass",
def IntelAMXTileConfigHoistingPass : Pass<"intel-amx-tile-config-hoisting-pass",
"func::FuncOp"> {
let summary = "Hoist tile configuration invoke xsmm calls";
let summary = "Hoist intel amx tile configuration invoke xsmm calls";
let description = [{
Run LICM on Tile configuration invoke calls.
Run LICM on intel amx tile configuration invoke calls.
}];

let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
Expand Down
49 changes: 25 additions & 24 deletions lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,16 @@ struct ConvertFusedBrgemmXsmmOp : public OpRewritePattern<FusedBrgemmOp> {
}
};

struct ConvertTileConfigXsmmOp : public OpRewritePattern<TileConfigOp> {
using OpRewritePattern<TileConfigOp>::OpRewritePattern;
struct ConvertIntelAMXTileConfigXsmmOp
: public OpRewritePattern<IntelAMXTileConfigOp> {
using OpRewritePattern<IntelAMXTileConfigOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TileConfigOp tileConfigOp,
LogicalResult matchAndRewrite(IntelAMXTileConfigOp tileConfigOp,
PatternRewriter &rewriter) const override {
std::string funcName = "xsmm_tile_config_invoke";
buildInvokeCall(rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp,
xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16));
std::string funcName = "xsmm_intel_amx_tile_config_invoke";
buildInvokeCall(
rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp,
xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16));
rewriter.eraseOp(tileConfigOp);
return success();
}
Expand Down Expand Up @@ -208,10 +210,9 @@ static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc,
return call;
}

template <typename OpTy,
typename = std::enable_if_t<
std::is_same<OpTy, xsmm::UnaryDispatchOp>::value ||
std::is_same<OpTy, xsmm::BinaryDispatchOp>::value>>
template <typename OpTy, typename = std::enable_if_t<
std::is_same<OpTy, xsmm::UnaryDispatchOp>::value ||
std::is_same<OpTy, xsmm::BinaryDispatchOp>::value>>
void addKindOperand(RewriterBase &rewriter, OpTy dispatchOp,
SmallVectorImpl<Value> &dispatchOperands,
SmallVectorImpl<Type> &dispatchOperandTypes) {
Expand Down Expand Up @@ -240,13 +241,13 @@ void addKindOperand(RewriterBase &rewriter, FusedBrgemmDispatchOp dispatchOp,
/* do nothing */
}

void addKindOperand(RewriterBase &rewriter, TileConfigDispatchOp dispatchOp,
void addKindOperand(RewriterBase &rewriter,
IntelAMXTileConfigDispatchOp dispatchOp,
SmallVectorImpl<Value> &dispatchOperands,
SmallVectorImpl<Type> &dispatchOperandTypes) {
/* do nothing */
}


static int64_t getOredFlags(ArrayAttr flags) {
int64_t oredFlag = 0;
for (auto flag : flags) {
Expand Down Expand Up @@ -390,13 +391,14 @@ struct ConvertUnaryDispatchOp : public OpRewritePattern<UnaryDispatchOp> {
}
};

struct ConvertTileConfigDispatchOp : public OpRewritePattern<TileConfigDispatchOp> {
using OpRewritePattern<TileConfigDispatchOp>::OpRewritePattern;
struct ConvertIntelAMXTileConfigDispatchOp
: public OpRewritePattern<IntelAMXTileConfigDispatchOp> {
using OpRewritePattern<IntelAMXTileConfigDispatchOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TileConfigDispatchOp dispatchOp,
LogicalResult matchAndRewrite(IntelAMXTileConfigDispatchOp dispatchOp,
PatternRewriter &rewriter) const override {
return buildDispatchOp<TileConfigDispatchOp>(rewriter, dispatchOp,
"xsmm_tile_config_dispatch");
return buildDispatchOp<IntelAMXTileConfigDispatchOp>(
rewriter, dispatchOp, "xsmm_intel_amx_tile_config_dispatch");
}
};

Expand All @@ -423,13 +425,12 @@ struct ConvertXsmmToFunc
: public tpp::impl::ConvertXsmmToFuncBase<ConvertXsmmToFunc> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns
.add<ConvertBinaryXsmmOp, ConvertUnaryXsmmOp,
ConvertGemmXsmmOp, ConvertBrgemmXsmmOp, ConvertFusedBrgemmXsmmOp, ConvertTileConfigXsmmOp>(
patterns.getContext());
patterns.add<ConvertBinaryDispatchOp,
ConvertUnaryDispatchOp, ConvertGemmDispatchOp,
ConvertBrgemmDispatchOp, ConvertFusedBrgemmOp, ConvertTileConfigDispatchOp>(
patterns.add<ConvertBinaryXsmmOp, ConvertUnaryXsmmOp, ConvertGemmXsmmOp,
ConvertBrgemmXsmmOp, ConvertFusedBrgemmXsmmOp,
ConvertIntelAMXTileConfigXsmmOp>(patterns.getContext());
patterns.add<ConvertBinaryDispatchOp, ConvertUnaryDispatchOp,
ConvertGemmDispatchOp, ConvertBrgemmDispatchOp,
ConvertFusedBrgemmOp, ConvertIntelAMXTileConfigDispatchOp>(
patterns.getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
Expand Down
12 changes: 6 additions & 6 deletions lib/TPP/Dialect/Xsmm/XsmmOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ static void printerDataTypeImpl(OpAsmPrinter &printer, OpTy op) {

template <typename AttrTy>
static void printerFlagsImpl(OpAsmPrinter &printer,
const std::function<ArrayAttr()>& fn,
const std::function<ArrayAttr()> &fn,
const std::string_view &flagsName) {
printer << " " << flagsName << " = (";
llvm::interleaveComma(fn(), printer, [&](auto &flag) {
Expand Down Expand Up @@ -235,15 +235,15 @@ void BinaryDispatchOp::print(OpAsmPrinter &printer) {
printerDataTypeImpl<BinaryDispatchOp>(printer, *this);
}

void TileConfigDispatchOp::print(OpAsmPrinter &printer) {
printerInputImpl<TileConfigDispatchOp>(printer, *this);
void IntelAMXTileConfigDispatchOp::print(OpAsmPrinter &printer) {
printerInputImpl<IntelAMXTileConfigDispatchOp>(printer, *this);
auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); };
printerFlagsImpl<GemmFlagsAttr>(printer, getOpFlags, FLAGS_NAME);
printerDataTypeImpl<TileConfigDispatchOp>(printer, *this);
printerDataTypeImpl<IntelAMXTileConfigDispatchOp>(printer, *this);
}

ParseResult TileConfigDispatchOp::parse(OpAsmParser &parser,
OperationState &result) {
ParseResult IntelAMXTileConfigDispatchOp::parse(OpAsmParser &parser,
OperationState &result) {
if (failed(parseInputImpl(parser, result)) ||
failed(parserFlagsImpl<GemmFlags>(parser, result, FLAGS_NAME)))
return failure();
Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ add_mlir_library(TPPTransforms
TransformUtils.cpp
CombineXsmmPass.cpp
SCFParallelLoopTiling.cpp
TileConfig.cpp
TileConfigHoisting.cpp
IntelAMXTileConfig.cpp
IntelAMXTileConfigHoisting.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- TileConfig.cpp -----------------------------------------------------===//
//===- IntelAMXTileConfig.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -19,7 +19,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_TILECONFIGINSERTIONPASS
#define GEN_PASS_DEF_INTELAMXTILECONFIGINSERTIONPASS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir
Expand All @@ -35,10 +35,9 @@ static void appendBrgemmFlags(SmallVector<Attribute> &attributes,
auto flags =
dyn_cast<DispatchOpTy>(opTy.getOperand(0).getDefiningOp()).getFlags();
for (auto flagItr : flags) {
if (flagItr == xsmm::GemmFlagsAttr::get(rewriter.getContext(),
xsmm::GemmFlags::NONE)) {
if (flagItr ==
xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::NONE))
return;
}
attributes.push_back(flagItr);
}

Expand All @@ -48,7 +47,7 @@ static void appendBrgemmFlags(SmallVector<Attribute> &attributes,
}

template <typename InvokeOpTy, typename DispatchOpTy>
struct TileConfig : OpRewritePattern<InvokeOpTy> {
struct IntelAMXTileConfig : OpRewritePattern<InvokeOpTy> {
using OpRewritePattern<InvokeOpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(InvokeOpTy op,
Expand All @@ -58,22 +57,20 @@ struct TileConfig : OpRewritePattern<InvokeOpTy> {
return failure();
auto flags =
dyn_cast<DispatchOpTy>(op.getOperand(0).getDefiningOp()).getFlags();
for (auto flagItr : flags) {
for (auto flagItr : flags)
if (flagItr == xsmm::GemmFlagsAttr::get(
rewriter.getContext(),
mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG) ||
flagItr == xsmm::GemmFlagsAttr::get(
rewriter.getContext(),
mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) {
mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG))
return failure();
}
}

SmallVector<Attribute> attributesSetup;
attributesSetup.push_back(xsmm::GemmFlagsAttr::get(
rewriter.getContext(), xsmm::GemmFlags::NO_RESET_TILECONFIG));
appendBrgemmFlags<InvokeOpTy, DispatchOpTy>(attributesSetup, rewriter, op);
auto tileConfigSetup = rewriter.create<xsmm::TileConfigDispatchOp>(
auto tileConfigSetup = rewriter.create<xsmm::IntelAMXTileConfigDispatchOp>(
op.getLoc(), rewriter.getI64Type(),
DenseI64ArrayAttr::get(
rewriter.getContext(),
Expand All @@ -86,7 +83,7 @@ struct TileConfig : OpRewritePattern<InvokeOpTy> {
attributesReset.push_back(xsmm::GemmFlagsAttr::get(
rewriter.getContext(), xsmm::GemmFlags::NO_SETUP_TILECONFIG));
appendBrgemmFlags<InvokeOpTy, DispatchOpTy>(attributesReset, rewriter, op);
auto tileConfigReset = rewriter.create<xsmm::TileConfigDispatchOp>(
auto tileConfigReset = rewriter.create<xsmm::IntelAMXTileConfigDispatchOp>(
op.getLoc(), rewriter.getI64Type(),
DenseI64ArrayAttr::get(
rewriter.getContext(),
Expand All @@ -110,8 +107,8 @@ struct TileConfig : OpRewritePattern<InvokeOpTy> {
op.getLoc(), MemRefType::get({64}, rewriter.getI8Type()));

ValueRange tileConfigInputs{alloca};
rewriter.create<mlir::xsmm::TileConfigOp>(op.getLoc(), tileConfigSetup,
tileConfigInputs);
rewriter.create<mlir::xsmm::IntelAMXTileConfigOp>(
op.getLoc(), tileConfigSetup, tileConfigInputs);

SmallVector<Value> invokeOperands;
invokeOperands.push_back(dispatch);
Expand All @@ -124,22 +121,23 @@ struct TileConfig : OpRewritePattern<InvokeOpTy> {
invokeOperands);

ValueRange tileResetInputs{alloca};
rewriter.create<mlir::xsmm::TileConfigOp>(op.getLoc(), tileConfigReset,
tileResetInputs);
rewriter.create<mlir::xsmm::IntelAMXTileConfigOp>(
op.getLoc(), tileConfigReset, tileResetInputs);

// rewriter.create<memref::DeallocOp>(op.getLoc(), alloca);
rewriter.eraseOp(op);
rewriter.eraseOp(op.getOperand(0).getDefiningOp());
return success();
}
};

struct TileConfigInsertionPass
: public impl::TileConfigInsertionPassBase<TileConfigInsertionPass> {
struct IntelAMXTileConfigInsertionPass
: public impl::IntelAMXTileConfigInsertionPassBase<
IntelAMXTileConfigInsertionPass> {
void populateCombinePatterns(RewritePatternSet &patterns) {
patterns.add<TileConfig<xsmm::BrgemmOp, xsmm::BrgemmDispatchOp>>(
patterns.add<IntelAMXTileConfig<xsmm::BrgemmOp, xsmm::BrgemmDispatchOp>>(
patterns.getContext());
patterns.add<TileConfig<xsmm::FusedBrgemmOp, xsmm::FusedBrgemmDispatchOp>>(
patterns.add<
IntelAMXTileConfig<xsmm::FusedBrgemmOp, xsmm::FusedBrgemmDispatchOp>>(
patterns.getContext());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- TileConfigHoisting.cpp ---------------------------------------------===//
//===- IntelAMXTileConfigHoisting.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -10,16 +10,14 @@
//
//===----------------------------------------------------------------------===//
#include "TPP/Dialect/Xsmm/XsmmOps.h"
#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_TILECONFIGHOISTINGPASS
#define GEN_PASS_DEF_INTELAMXTILECONFIGHOISTINGPASS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir
Expand All @@ -30,31 +28,32 @@ using namespace mlir::xsmm;
namespace mlir {
namespace tpp {

struct TileConfigHoisting : OpRewritePattern<memref::AllocaOp> {
struct IntelAMXTileConfigHoisting : OpRewritePattern<memref::AllocaOp> {
using OpRewritePattern<memref::AllocaOp>::OpRewritePattern;

LogicalResult matchAndRewrite(memref::AllocaOp alloca,
PatternRewriter &rewriter) const override {

xsmm::TileConfigOp firstTileConfig, secondTileConfig;
xsmm::IntelAMXTileConfigOp firstTileConfig, secondTileConfig;
for (auto *user : alloca->getUsers()) {
if (!dyn_cast<xsmm::TileConfigOp>(user)) {
if (!dyn_cast<xsmm::IntelAMXTileConfigOp>(user)) {
return failure();
}
auto flags =
dyn_cast<xsmm::TileConfigDispatchOp>(
dyn_cast<xsmm::TileConfigOp>(user).getOperand(0).getDefiningOp())
.getFlags();
auto flags = dyn_cast<xsmm::IntelAMXTileConfigDispatchOp>(
dyn_cast<xsmm::IntelAMXTileConfigOp>(user)
.getOperand(0)
.getDefiningOp())
.getFlags();
for (auto flagItr : flags) {
if (flagItr == xsmm::GemmFlagsAttr::get(
rewriter.getContext(),
mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG)) {
firstTileConfig = dyn_cast<xsmm::TileConfigOp>(user);
firstTileConfig = dyn_cast<xsmm::IntelAMXTileConfigOp>(user);

} else if (flagItr == xsmm::GemmFlagsAttr::get(
rewriter.getContext(),
mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) {
secondTileConfig = dyn_cast<xsmm::TileConfigOp>(user);
secondTileConfig = dyn_cast<xsmm::IntelAMXTileConfigOp>(user);
}
}
}
Expand Down Expand Up @@ -85,10 +84,11 @@ struct TileConfigHoisting : OpRewritePattern<memref::AllocaOp> {
}
};

struct TileConfigHoistingPass
: public impl::TileConfigHoistingPassBase<TileConfigHoistingPass> {
struct IntelAMXTileConfigHoistingPass
: public impl::IntelAMXTileConfigHoistingPassBase<
IntelAMXTileConfigHoistingPass> {
void populateCombinePatterns(RewritePatternSet &patterns) {
patterns.add<TileConfigHoisting>(patterns.getContext());
patterns.add<IntelAMXTileConfigHoisting>(patterns.getContext());
}

void runOnOperation() override {
Expand Down
Loading

0 comments on commit 96d146d

Please sign in to comment.