diff --git a/include/TPP/Dialect/Xsmm/XsmmEnum.td b/include/TPP/Dialect/Xsmm/XsmmEnum.td
index a6d67c389..71a13e425 100644
--- a/include/TPP/Dialect/Xsmm/XsmmEnum.td
+++ b/include/TPP/Dialect/Xsmm/XsmmEnum.td
@@ -76,7 +76,9 @@ def Xsmm_GemmFlags : I64EnumAttr<
     I64EnumAttrCase<"BETA_0", 4, "beta_0">,
     I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">,
     I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">,
-    I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">
-  ]> {
+    I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">,
+    I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">,
+    I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig">
+  ]> {
   let cppNamespace = "mlir::xsmm";
 }
diff --git a/include/TPP/Dialect/Xsmm/XsmmOps.td b/include/TPP/Dialect/Xsmm/XsmmOps.td
index b9cfad133..0edb2da4c 100644
--- a/include/TPP/Dialect/Xsmm/XsmmOps.td
+++ b/include/TPP/Dialect/Xsmm/XsmmOps.td
@@ -263,6 +263,15 @@ def Xsmm_GemmDispatchOp : Xsmm_GemmLikeOp<"gemm.dispatch"> {
 def Xsmm_BrgemmDispatchOp : Xsmm_GemmLikeOp<"brgemm.dispatch"> {
   let summary = "dispatch for brgemm operation.";
   let hasVerifier = 1;
+
+}
+
+//===----------------------------------------------------------------------===//
+// IntelAMXTileConfigDispatchOp
+//===----------------------------------------------------------------------===//
+
+def Xsmm_IntelAMXTileConfigDispatchOp : Xsmm_GemmLikeOp<"IntelAMXtileConfig.dispatch"> {
+  let summary = "dispatch for Intel AMX tileConfig operation.";
 }
 
 //===----------------------------------------------------------------------===//
@@ -298,4 +307,14 @@ def Xsmm_FusedBrgemmDispatchOp : Xsmm_Op<"fused_brgemm.dispatch", [Pure]> {
   let hasVerifier = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// IntelAMXTileConfigOp
+//===----------------------------------------------------------------------===//
+
+def Xsmm_IntelAMXTileConfigOp : Xsmm_Op<"IntelAMXtileConfig",
+    [MemoryEffects<[MemWrite, MemRead]>]> {
+  let summary = "invoke for Intel AMX tileConfig operation.";
+  let arguments = (ins I64:$dispatch, Variadic<AnyType>:$inputs);
+}
+
 #endif // TPP_XSMM_OPS
diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td
index dc51cfb08..8df063e6b 100644
--- a/include/TPP/Passes.td
+++ b/include/TPP/Passes.td
@@ -204,7 +204,10 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> {
   let options= [
     Option<"linalgToLoops", "linalg-to-loops",
            "bool", /*default=*/"false",
-           "Skip all TPP transformations. Lower linalg directly to loops.">
+           "Skip all TPP transformations. Lower linalg directly to loops.">,
+    ListOption<"parallelTaskGrid", "parallel-task-grid",
+               "unsigned", "Grid-sizes for parallel tasks.">
   ];
 }
 
@@ -289,6 +292,12 @@ def LocalDialectsLowering : Pass<"lower-local-dialects", "ModuleOp"> {
                              "tensor::TensorDialect",
                              "xsmm::XsmmDialect",
                              "LLVM::LLVMDialect"];
+  let options = [
+    ListOption<"parallelTaskGrid", "parallel-task-grid",
+               "unsigned", "Grid-sizes for parallel tasks.">
+  ];
+
 }
 
 def Postprocessing : Pass<"postprocess", "func::FuncOp"> {
@@ -470,10 +479,11 @@ def FoldXsmmFlags : Pass<"fold-xsmm-flags", "func::FuncOp"> {
   let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
 }
 
+
 def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> {
   let summary = "Tile parallel loops";
   let options = [
-    ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t",
+    ListOption<"tileSizes", "parallel-loop-tile-sizes", "unsigned",
                "Factors to tile parallel loops by">,
     Option<"noMinMaxBounds", "no-min-max-bounds", "bool",
            /*default=*/"false",
@@ -494,4 +504,24 @@ def GpuInlineConstants : Pass<"gpu-inline-constants", "func::FuncOp"> {
                              "arith::ArithDialect"];
 }
 
+def IntelAMXTileConfigInsertionPass : Pass<"intel-amx-tile-config-insertion-pass",
+                                           "func::FuncOp"> {
+  let summary = "Insert Intel AMX tile-configuration XSMM calls";
+  let description = [{
+    Insert Intel AMX tile-configuration XSMM calls.
+  }];
+
+  let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
+}
+
+def IntelAMXTileConfigHoistingPass : Pass<"intel-amx-tile-config-hoisting-pass",
+                                          "func::FuncOp"> {
+  let summary = "Hoist Intel AMX tile-configuration invoke XSMM calls";
+  let description = [{
+    Run LICM on Intel AMX tile-configuration invoke calls.
+  }];
+
+  let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
+}
+
 #endif // TPP_DIALECT_TPP_PASSES
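The two new GemmFlags cases deliberately carry the bit values 64 and 128 so that, like the existing cases, they can be ORed directly into the flag bitfield handed to LIBXSMM (they are intended to line up with LIBXSMM's no-reset/no-setup tile-config gemm flags). A minimal sketch of how they surface in the dialect assembly, using the shapes exercised by the tests at the end of this patch (SSA names illustrative):

    %setup  = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
    %kernel = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16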
diff --git a/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp b/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
index 737ee27e1..84183c3f8 100644
--- a/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
+++ b/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
@@ -171,6 +171,21 @@ struct ConvertFusedBrgemmXsmmOp : public OpRewritePattern<FusedBrgemmOp> {
   }
 };
 
+struct ConvertIntelAMXTileConfigXsmmOp
+    : public OpRewritePattern<IntelAMXTileConfigOp> {
+  using OpRewritePattern<IntelAMXTileConfigOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IntelAMXTileConfigOp tileConfigOp,
+                                PatternRewriter &rewriter) const override {
+    std::string funcName = "xsmm_intel_amx_tile_config_invoke";
+    buildInvokeCall(
+        rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp,
+        xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16));
+    rewriter.eraseOp(tileConfigOp);
+    return success();
+  }
+};
+
 static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc,
                                       ArrayRef<Value> dispatchOperands,
                                       ArrayRef<Type> dispatchOperandTypes,
@@ -195,10 +210,9 @@ static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc,
   return call;
 }
 
-template <typename OpTy, typename = std::enable_if_t<std::is_same<OpTy, GemmDispatchOp>::value ||
-                                  std::is_same<OpTy, BrgemmDispatchOp>::value>>
+template <typename OpTy, typename = std::enable_if_t<std::is_same<OpTy, GemmDispatchOp>::value ||
+                                                     std::is_same<OpTy, BrgemmDispatchOp>::value>>
 void addKindOperand(RewriterBase &rewriter, OpTy dispatchOp,
                     SmallVectorImpl<Value> &dispatchOperands,
                     SmallVectorImpl<Type> &dispatchOperandTypes) {
@@ -227,6 +241,13 @@ void addKindOperand(RewriterBase &rewriter, FusedBrgemmDispatchOp dispatchOp,
   /* do nothing */
 }
 
+void addKindOperand(RewriterBase &rewriter,
+                    IntelAMXTileConfigDispatchOp dispatchOp,
+                    SmallVectorImpl<Value> &dispatchOperands,
+                    SmallVectorImpl<Type> &dispatchOperandTypes) {
+  /* do nothing */
+}
+
 static int64_t getOredFlags(ArrayAttr flags) {
   int64_t oredFlag = 0;
   for (auto flag : flags) {
@@ -370,6 +391,17 @@ struct ConvertUnaryDispatchOp : public OpRewritePattern<UnaryDispatchOp> {
   }
 };
 
+struct ConvertIntelAMXTileConfigDispatchOp
+    : public OpRewritePattern<IntelAMXTileConfigDispatchOp> {
+  using OpRewritePattern<IntelAMXTileConfigDispatchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IntelAMXTileConfigDispatchOp dispatchOp,
+                                PatternRewriter &rewriter) const override {
+    return buildDispatchOp<IntelAMXTileConfigDispatchOp>(
+        rewriter, dispatchOp, "xsmm_intel_amx_tile_config_dispatch");
+  }
+};
+
 struct ConvertFusedBrgemmOp : public OpRewritePattern<FusedBrgemmDispatchOp> {
   using OpRewritePattern<FusedBrgemmDispatchOp>::OpRewritePattern;
 
@@ -393,13 +425,12 @@ struct ConvertXsmmToFunc
     : public tpp::impl::ConvertXsmmToFuncBase<ConvertXsmmToFunc> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    patterns
-        .add<ConvertTernaryXsmmOp, ConvertBinaryXsmmOp, ConvertUnaryXsmmOp,
-             ConvertFusedBrgemmXsmmOp>(
-            patterns.getContext());
-    patterns.add<ConvertTernaryDispatchOp, ConvertBinaryDispatchOp,
-                 ConvertUnaryDispatchOp, ConvertFusedBrgemmOp>(
+    patterns.add<ConvertTernaryXsmmOp, ConvertBinaryXsmmOp, ConvertUnaryXsmmOp,
+                 ConvertFusedBrgemmXsmmOp, ConvertIntelAMXTileConfigXsmmOp>(
+        patterns.getContext());
+    patterns.add<ConvertTernaryDispatchOp, ConvertBinaryDispatchOp,
+                 ConvertUnaryDispatchOp, ConvertFusedBrgemmOp,
+                 ConvertIntelAMXTileConfigDispatchOp>(
         patterns.getContext());
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp
index 65628778b..b5c20285c 100644
--- a/lib/TPP/DefaultPipeline.cpp
+++ b/lib/TPP/DefaultPipeline.cpp
@@ -47,10 +47,10 @@ llvm::cl::opt<bool>
                    llvm::cl::init(false));
 
 // Control grid parallelism sizes.
-llvm::cl::list<int64_t>
+llvm::cl::list<unsigned>
     parallelTaskGrid("parallel-task-grid",
                      llvm::cl::desc("Grid-sizes for parallel tasks"),
-                     llvm::cl::list_init(SmallVector<int64_t>{2, 8}),
+                     llvm::cl::list_init(SmallVector<unsigned>{2, 8}),
                      llvm::cl::CommaSeparated);
 
 namespace mlir {
@@ -127,7 +128,8 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
       pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
     } else {
       // Apply the default preprocessing pass
-      DefaultTppPassesOptions tppDefaultOptions{linalgToLoops};
+      DefaultTppPassesOptions tppDefaultOptions{linalgToLoops,
+                                                parallelTaskGrid};
       pm.addPass(createDefaultTppPasses(tppDefaultOptions));
     }
 
@@ -140,12 +141,8 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
     pm.addPass(tpp::createConvertPerfToFunc());
     pm.addPass(createConvertTensorToLinalgPass());
     pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
-    if (defParallel) {
-      mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
-      tilingOptions.tileSizes = parallelTaskGrid;
-      pm.addPass(createSCFParallelLoopTiling(tilingOptions));
+    if (defParallel)
       pm.addPass(createConvertSCFToOpenMPPass());
-    }
     pm.addPass(createConvertVectorToSCFPass());
     pm.addPass(arith::createArithExpandOpsPass());
     pm.addPass(createLowerAffinePass());
diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp
index b7905cf24..fec85685d 100644
--- a/lib/TPP/DefaultTppPasses.cpp
+++ b/lib/TPP/DefaultTppPasses.cpp
@@ -77,6 +77,10 @@ struct LocalDialectsLowering
     : public tpp::impl::LocalDialectsLoweringBase<LocalDialectsLowering>,
       UtilityPassBase<ModuleOp> {
+  LocalDialectsLowering() {}
+  LocalDialectsLowering(const LocalDialectsLoweringOptions &options) {
+    parallelTaskGrid = options.parallelTaskGrid;
+  }
 
   void runOnOperation() override {
     auto module = getOperation();
@@ -106,6 +110,16 @@ struct LocalDialectsLowering
     // that they are hoisted out of loops.
     pm.addNestedPass<func::FuncOp>(createCleanup());
 
+    mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
+    tilingOptions.tileSizes = parallelTaskGrid;
+    pm.addPass(createSCFParallelLoopTiling(tilingOptions));
+
+    pm.addNestedPass<func::FuncOp>(createIntelAMXTileConfigInsertionPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createIntelAMXTileConfigHoistingPass());
+
     pm.addPass(createConvertXsmmToFunc());
     pm.addPass(createConvertPerfToFunc());
   }
@@ -310,7 +324,8 @@ struct DefaultTppPasses
     pm.addPass(createConvertForAllToParallelOp());
 
     // Covert all local TPP-related dialects.
-    pm.addPass(createLocalDialectsLowering());
+    LocalDialectsLoweringOptions localDialectsLowering{parallelTaskGrid};
+    pm.addPass(createLocalDialectsLowering(localDialectsLowering));
 
     // Clean up after the default pipeline.
     pm.addNestedPass<func::FuncOp>(createPostprocessing());
diff --git a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp
index 1b527743a..139c13219 100644
--- a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp
+++ b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp
@@ -179,7 +179,7 @@ static void printerDataTypeImpl(OpAsmPrinter &printer, OpTy op) {
 
 template <typename AttrTy>
 static void printerFlagsImpl(OpAsmPrinter &printer,
-                             const std::function<ArrayAttr()>& fn,
+                             const std::function<ArrayAttr()> &fn,
                              const std::string_view &flagsName) {
   printer << " " << flagsName << " = (";
   llvm::interleaveComma(fn(), printer, [&](auto &flag) {
@@ -235,6 +235,21 @@ void BinaryDispatchOp::print(OpAsmPrinter &printer) {
   printerDataTypeImpl(printer, *this);
 }
 
+void IntelAMXTileConfigDispatchOp::print(OpAsmPrinter &printer) {
+  printerInputImpl(printer, *this);
+  auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); };
+  printerFlagsImpl<xsmm::GemmFlagsAttr>(printer, getOpFlags, FLAGS_NAME);
+  printerDataTypeImpl(printer, *this);
+}
+
+ParseResult IntelAMXTileConfigDispatchOp::parse(OpAsmParser &parser,
+                                                OperationState &result) {
+  if (failed(parseInputImpl(parser, result)) ||
+      failed(parserFlagsImpl<xsmm::GemmFlags>(parser, result, FLAGS_NAME)))
+    return failure();
+  return parseDataTypeImpl(parser, result);
+}
+
 template <typename FlagsTy>
 static LogicalResult verifyUniquenessAndConsistency(ArrayAttr flags,
                                                     Operation *op,
diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt
index f3644e3ec..53d60a390 100644
--- a/lib/TPP/Transforms/CMakeLists.txt
+++ b/lib/TPP/Transforms/CMakeLists.txt
@@ -17,6 +17,8 @@ add_mlir_library(TPPTransforms
   TransformUtils.cpp
   CombineXsmmPass.cpp
   SCFParallelLoopTiling.cpp
+  IntelAMXTileConfig.cpp
+  IntelAMXTileConfigHoisting.cpp
 
   ADDITIONAL_HEADER_DIRS
     ${PROJECT_SOURCE_DIR}/include/TPP
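The insertion pass implemented below rewrites each BF16 brgemm-style invoke into a configure/compute/release triple that shares one 64-byte tile-state buffer. Condensed from the pass-tileconfig-insertion.mlir test at the end of this patch (SSA names shortened):

    %setup  = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
    %reset  = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
    %kernel = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
    %state  = memref.alloca() : memref<64xi8>
    "xsmm.IntelAMXtileConfig"(%setup, %state) : (i64, memref<64xi8>) -> ()
    xsmm.brgemm(data_type = bf16, %kernel, %a, %b, %c, %batch) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
    "xsmm.IntelAMXtileConfig"(%reset, %state) : (i64, memref<64xi8>) -> ()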
diff --git a/lib/TPP/Transforms/IntelAMXTileConfig.cpp b/lib/TPP/Transforms/IntelAMXTileConfig.cpp
new file mode 100644
index 000000000..0f90a1446
--- /dev/null
+++ b/lib/TPP/Transforms/IntelAMXTileConfig.cpp
@@ -0,0 +1,152 @@
+//===- IntelAMXTileConfig.cpp ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file inserts tile configuration calls.
+//
+//===----------------------------------------------------------------------===//
+#include "TPP/Dialect/Xsmm/XsmmOps.h"
+#include "TPP/Dialect/Xsmm/XsmmUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace tpp {
+#define GEN_PASS_DEF_INTELAMXTILECONFIGINSERTIONPASS
+#include "TPP/Passes.h.inc"
+} // namespace tpp
+} // namespace mlir
+
+using namespace mlir;
+
+namespace mlir {
+namespace tpp {
+
+template <typename DispatchOpTy, typename InvokeOpTy>
+static void appendBrgemmFlags(SmallVector<Attribute> &attributes,
+                              PatternRewriter &rewriter, InvokeOpTy opTy) {
+  auto flags =
+      dyn_cast<DispatchOpTy>(opTy.getOperand(0).getDefiningOp()).getFlags();
+  for (auto flagItr : flags) {
+    if (flagItr ==
+        xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::NONE))
+      return;
+    attributes.push_back(flagItr);
+  }
+
+  if (attributes.empty())
+    attributes.push_back(
+        xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::NONE));
+}
+
+template <typename InvokeOpTy, typename DispatchOpTy>
+struct IntelAMXTileConfig : OpRewritePattern<InvokeOpTy> {
+  using OpRewritePattern<InvokeOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InvokeOpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()) !=
+        xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16))
+      return failure();
+    auto flags =
+        dyn_cast<DispatchOpTy>(op.getOperand(0).getDefiningOp()).getFlags();
+    for (auto flagItr : flags)
+      if (flagItr == xsmm::GemmFlagsAttr::get(
+                         rewriter.getContext(),
+                         mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG) ||
+          flagItr == xsmm::GemmFlagsAttr::get(
+                         rewriter.getContext(),
+                         mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG))
+        return failure();
+
+    SmallVector<Attribute> attributesSetup;
+    attributesSetup.push_back(xsmm::GemmFlagsAttr::get(
+        rewriter.getContext(), xsmm::GemmFlags::NO_RESET_TILECONFIG));
+    appendBrgemmFlags<DispatchOpTy>(attributesSetup, rewriter, op);
+    auto tileConfigSetup = rewriter.create<xsmm::IntelAMXTileConfigDispatchOp>(
+        op.getLoc(), rewriter.getI64Type(),
+        DenseI64ArrayAttr::get(
+            rewriter.getContext(),
+            dyn_cast<DispatchOpTy>(op.getOperand(0).getDefiningOp())
+                .getInputs()),
+        rewriter.getArrayAttr(attributesSetup),
+        xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()));
+
+    SmallVector<Attribute> attributesReset;
+    attributesReset.push_back(xsmm::GemmFlagsAttr::get(
+        rewriter.getContext(), xsmm::GemmFlags::NO_SETUP_TILECONFIG));
+    appendBrgemmFlags<DispatchOpTy>(attributesReset, rewriter, op);
+    auto tileConfigReset = rewriter.create<xsmm::IntelAMXTileConfigDispatchOp>(
+        op.getLoc(), rewriter.getI64Type(),
+        DenseI64ArrayAttr::get(
+            rewriter.getContext(),
+            dyn_cast<DispatchOpTy>(op.getOperand(0).getDefiningOp())
+                .getInputs()),
+        rewriter.getArrayAttr(attributesReset),
+        xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()));
+
+    SmallVector<Attribute> attributesBrgemm;
+    attributesBrgemm.push_back(xsmm::GemmFlagsAttr::get(
+        rewriter.getContext(), xsmm::GemmFlags::NO_RESET_TILECONFIG));
+    attributesBrgemm.push_back(xsmm::GemmFlagsAttr::get(
+        rewriter.getContext(), xsmm::GemmFlags::NO_SETUP_TILECONFIG));
+    appendBrgemmFlags<DispatchOpTy>(attributesBrgemm, rewriter, op);
+
+    auto dispatch = dyn_cast<DispatchOpTy>(
+        rewriter.clone(*op.getOperand(0).getDefiningOp()));
+    dispatch.setFlagsAttr(rewriter.getArrayAttr(attributesBrgemm));
+
+    auto alloca = rewriter.create<memref::AllocaOp>(
+        op.getLoc(), MemRefType::get({64}, rewriter.getI8Type()));
+
+    ValueRange tileConfigInputs{alloca};
+    rewriter.create<xsmm::IntelAMXTileConfigOp>(
+        op.getLoc(), tileConfigSetup, tileConfigInputs);
+
+    SmallVector<Value> invokeOperands;
+    invokeOperands.push_back(dispatch);
+    auto opItr = op->getOperands().begin();
+    std::advance(opItr, 1);
+    invokeOperands.append(opItr, op->getOperands().end());
+    rewriter.create<InvokeOpTy>(
+        op.getLoc(),
+        xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()),
+        invokeOperands);
+
+    ValueRange tileResetInputs{alloca};
+    rewriter.create<xsmm::IntelAMXTileConfigOp>(
+        op.getLoc(), tileConfigReset, tileResetInputs);
+
+    // Capture the old dispatch op before erasing its last user.
+    Operation *oldDispatch = op.getOperand(0).getDefiningOp();
+    rewriter.eraseOp(op);
+    if (oldDispatch->getUsers().empty())
+      rewriter.eraseOp(oldDispatch);
+    return success();
+  }
+};
+
+struct IntelAMXTileConfigInsertionPass
+    : public impl::IntelAMXTileConfigInsertionPassBase<
+          IntelAMXTileConfigInsertionPass> {
+  void populateCombinePatterns(RewritePatternSet &patterns) {
+    patterns.add<IntelAMXTileConfig<xsmm::BrgemmOp, xsmm::BrgemmDispatchOp>>(
+        patterns.getContext());
+    patterns.add<
+        IntelAMXTileConfig<xsmm::FusedBrgemmOp, xsmm::FusedBrgemmDispatchOp>>(
+        patterns.getContext());
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateCombinePatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+} // namespace tpp
+} // namespace mlir
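The hoisting pass implemented next moves the scratch alloca and the paired config ops out of the innermost loops to the body of the enclosing scf.parallel, so each worker configures the tile registers once, runs all of its brgemm calls, and resets on exit. The target shape, condensed from pass-tileconfig-hoisting-pass.mlir below (names shortened):

    scf.parallel (%i, %j) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) {
      %state = memref.alloca() : memref<64xi8>
      "xsmm.IntelAMXtileConfig"(%setup, %state) : (i64, memref<64xi8>) -> ()
      scf.for %ii = %c0 to %c2 step %c1 {
        scf.for %jj = %c0 to %c8 step %c1 {
          xsmm.brgemm(data_type = bf16, %kernel, %a, %b, %c, %batch) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
        }
      }
      "xsmm.IntelAMXtileConfig"(%reset, %state) : (i64, memref<64xi8>) -> ()
      scf.reduce
    }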
diff --git a/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp b/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp
new file mode 100644
index 000000000..df0ba42c1
--- /dev/null
+++ b/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp
@@ -0,0 +1,101 @@
+//===- IntelAMXTileConfigHoisting.cpp -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements tile configuration hoisting on parallel loops.
+//
+//===----------------------------------------------------------------------===//
+#include "TPP/Dialect/Xsmm/XsmmOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace tpp {
+#define GEN_PASS_DEF_INTELAMXTILECONFIGHOISTINGPASS
+#include "TPP/Passes.h.inc"
+} // namespace tpp
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::xsmm;
+
+namespace mlir {
+namespace tpp {
+
+struct IntelAMXTileConfigHoisting : OpRewritePattern<memref::AllocaOp> {
+  using OpRewritePattern<memref::AllocaOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::AllocaOp alloca,
+                                PatternRewriter &rewriter) const override {
+
+    xsmm::IntelAMXTileConfigOp firstTileConfig, secondTileConfig;
+    for (auto *user : alloca->getUsers()) {
+      if (!dyn_cast<xsmm::IntelAMXTileConfigOp>(user)) {
+        return failure();
+      }
+      auto flags = dyn_cast<xsmm::IntelAMXTileConfigDispatchOp>(
+                       dyn_cast<xsmm::IntelAMXTileConfigOp>(user)
+                           .getOperand(0)
+                           .getDefiningOp())
+                       .getFlags();
+      for (auto flagItr : flags) {
+        if (flagItr == xsmm::GemmFlagsAttr::get(
+                           rewriter.getContext(),
+                           mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG)) {
+          firstTileConfig = dyn_cast<xsmm::IntelAMXTileConfigOp>(user);
+
+        } else if (flagItr == xsmm::GemmFlagsAttr::get(
+                                  rewriter.getContext(),
+                                  mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) {
+          secondTileConfig = dyn_cast<xsmm::IntelAMXTileConfigOp>(user);
+        }
+      }
+    }
+
+    scf::ParallelOp parallelOpParent = nullptr;
+    auto op = alloca.getOperation();
+    while (op) {
+      if (op->getParentOfType<scf::ParallelOp>()) {
+        if (&op->getParentOfType<scf::ParallelOp>().getRegion() ==
+            alloca->getParentRegion()) {
+          return failure();
+        }
+        parallelOpParent = op->getParentOfType<scf::ParallelOp>();
+        break;
+      }
+      op = op->getParentOp();
+    }
+
+    if (parallelOpParent == nullptr)
+      return failure();
+
+    rewriter.moveOpBefore(alloca, parallelOpParent.getBody(),
+                          parallelOpParent.getBody()->begin());
+    rewriter.moveOpAfter(firstTileConfig, alloca);
+    rewriter.moveOpBefore(secondTileConfig, parallelOpParent.getBody(),
+                          std::prev(parallelOpParent.getBody()->end(), 1));
+    return success();
+  }
+};
+
+struct IntelAMXTileConfigHoistingPass
+    : public impl::IntelAMXTileConfigHoistingPassBase<
+          IntelAMXTileConfigHoistingPass> {
+  void populateCombinePatterns(RewritePatternSet &patterns) {
+    patterns.add<IntelAMXTileConfigHoisting>(patterns.getContext());
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateCombinePatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+} // namespace tpp
+} // namespace mlir
diff --git a/lib/TPP/Transforms/SCFParallelLoopTiling.cpp b/lib/TPP/Transforms/SCFParallelLoopTiling.cpp
index f6632ed1a..8379e1293 100644
--- a/lib/TPP/Transforms/SCFParallelLoopTiling.cpp
+++ b/lib/TPP/Transforms/SCFParallelLoopTiling.cpp
@@ -54,7 +54,7 @@ using namespace mlir::scf;
 /// %i0 + j0 and %i1 + %j1.
 ///
 /// The old loop is replaced with the new one.
-void tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
+void tileParallelLoop(ParallelOp op, ArrayRef<unsigned> tileSizes,
                       bool noMinMaxBounds) {
   bool useParallelOp = false; /* TODO, need to implement this case */
@@ -244,7 +244,7 @@ namespace {
 struct SCFParallelLoopTiling
     : public tpp::impl::SCFParallelLoopTilingBase<SCFParallelLoopTiling> {
   SCFParallelLoopTiling(){};
-  SCFParallelLoopTiling(ArrayRef<int64_t> tileSizes,
+  SCFParallelLoopTiling(ArrayRef<unsigned> tileSizes,
                         bool noMinMaxBounds = false) {
     this->tileSizes = tileSizes;
     this->noMinMaxBounds = noMinMaxBounds;
diff --git a/runtime/Xsmm/XsmmRunnerUtils.cpp b/runtime/Xsmm/XsmmRunnerUtils.cpp
index b1ba23eb4..62b784f1f 100644
--- a/runtime/Xsmm/XsmmRunnerUtils.cpp
+++ b/runtime/Xsmm/XsmmRunnerUtils.cpp
@@ -14,6 +14,7 @@
 
 #include "XsmmRunnerUtils.h"
 #include "libxsmm.h" // NOLINT [build/include_subdir]
+#include "libxsmm_utils.h"
 
 // Helper function prototypes.
 static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape,
@@ -209,6 +210,41 @@ xsmm_binary_dispatch(const libxsmm_meltw_binary_type op_type,
   return reinterpret_cast<int64_t>(kernel);
 }
 
+extern "C" int64_t xsmm_intel_amx_tile_config_dispatch(
+    const libxsmm_datatype dtype, int64_t m, int64_t n, int64_t k, int64_t lda,
+    int64_t ldb, int64_t ldc, int64_t stride_a, int64_t stride_b,
+    const libxsmm_gemm_flags flags) {
+  libxsmm_blasint m_int = m;
+  libxsmm_blasint n_int = n;
+  libxsmm_blasint k_int = k;
+
+  libxsmm_gemm_shape l_shape;
+  libxsmm_bitfield l_cfg_flags = flags;
+
+  l_shape.m = n_int;
+  l_shape.n = m_int;
+  l_shape.k = k_int;
+  l_shape.lda = ldb;
+  l_shape.ldb = lda;
+  l_shape.ldc = ldc;
+  l_shape.a_in_type = dtype;
+  l_shape.b_in_type = dtype;
+  l_shape.out_type = dtype;
+  l_shape.comp_type =
+      dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype;
+
+  auto sgemm = libxsmm_dispatch_tilecfg_gemm(l_shape, l_cfg_flags);
+  if (!sgemm) {
+    fprintf(stderr, "failed to generate tileconfig func\n");
+    fprintf(stderr, "dtype: %u\n", dtype);
+    fprintf(stderr, "flags: %u\n", flags);
+    printXsmmStruct(l_shape);
+    exit(-1);
+  }
+
+  return reinterpret_cast<int64_t>(sgemm);
+}
+
 extern "C" void xsmm_unary_invoke(const libxsmm_datatype dType, int64_t addr,
                                   void *alignedPtrIn, int64_t offsetIn,
                                   void *alignedPtrOut, int64_t offsetOut) {
@@ -311,8 +347,8 @@ extern "C" int64_t xsmm_brgemm_dispatch(const libxsmm_datatype dtype, int64_t m,
   l_brconfig.br_stride_b_hint = stride_a * typeSize;
   l_brconfig.br_unroll_hint = 0;
 
-  auto sgemm = libxsmm_dispatch_brgemm(l_shape, l_flags, l_prefetch_flags,
-                                       l_brconfig);
+  auto sgemm =
+      libxsmm_dispatch_brgemm(l_shape, l_flags, l_prefetch_flags, l_brconfig);
   if (!sgemm) {
     fprintf(stderr, "failed to generate brgemm func\n");
     fprintf(stderr, "dtype: %u\n", dtype);
@@ -407,8 +443,8 @@ xsmm_fused_brgemm_dispatch(const libxsmm_datatype data_type, int64_t m,
   l_postops.d_binary_type = binary_op_type;
   l_postops.ldd = ldc;
 
-  auto sgemm = libxsmm_dispatch_brgemm_ext(
-      l_shape, l_flags, l_prefetch_flags, l_brconfig, l_argops, l_postops);
+  auto sgemm = libxsmm_dispatch_brgemm_ext(l_shape, l_flags, l_prefetch_flags,
+                                           l_brconfig, l_argops, l_postops);
   if (!sgemm) {
     fprintf(stderr, "failed to generate fused brgemm func\n");
     fprintf(stderr, "data_type: %u\n", data_type);
@@ -420,6 +456,18 @@ xsmm_fused_brgemm_dispatch(const libxsmm_datatype data_type, int64_t m,
   return reinterpret_cast<int64_t>(sgemm);
 }
 
+extern "C" MLIR_RUNNERUTILS_EXPORT void
+xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr,
+                                  void *tileState, int64_t offset) {
+  libxsmm_xmmfunction cfg_tr;
+
+  libxsmm_tilecfg_state *l_tilestate =
+      reinterpret_cast<libxsmm_tilecfg_state *>(tileState);
+
+  cfg_tr.tilecfg = reinterpret_cast<libxsmm_tilecfgfunction>(addr);
+  cfg_tr.tilecfg(l_tilestate);
+}
+
 static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape,
                             FILE *outfile) {
   fprintf(outfile, "M: %d\n", gemmShape.m);
diff --git a/runtime/Xsmm/XsmmRunnerUtils.h b/runtime/Xsmm/XsmmRunnerUtils.h
index 329b0a7c6..3d5048188 100644
--- a/runtime/Xsmm/XsmmRunnerUtils.h
+++ b/runtime/Xsmm/XsmmRunnerUtils.h
@@ -44,6 +44,10 @@ extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_fused_brgemm_dispatch(
     const libxsmm_meltw_binary_flags binary_flags,
     const libxsmm_meltw_binary_type binary_op_type);
 
+extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_intel_amx_tile_config_dispatch(
+    const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, int64_t,
+    int64_t, int64_t, int64_t, const libxsmm_gemm_flags);
+
 extern "C" MLIR_RUNNERUTILS_EXPORT void
 xsmm_gemm_invoke(const libxsmm_datatype dType, int64_t addr, void *alignedPtrA,
                  int64_t offsetA, void *alignedPtrB, int64_t offsetB,
@@ -74,4 +78,8 @@ extern "C" MLIR_RUNNERUTILS_EXPORT void xsmm_fused_brgemm_invoke(
     int64_t offsetA, void *alignedPtrB, int64_t offsetB, void *alignedPtrC,
     int64_t offsetC, void *alignedPtrD, int64_t offsetD, int64_t numBatches);
 
+extern "C" MLIR_RUNNERUTILS_EXPORT void
+xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr,
+                                  void *alignedPtrA, int64_t offset);
+
 #endif // TPP_EXECUTIONENGINE_CRUNNERUTILS_H
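After ConvertXsmmToFunc runs, the tile-config dispatch and invoke ops become plain calls into these runtime entry points. Roughly (a hand-written sketch, not the output of any test in this patch; the memref operand is shown unexpanded for brevity):

    %cfg = func.call @xsmm_intel_amx_tile_config_dispatch(%dtype, %m, %n, %k, %lda, %ldb, %ldc, %sa, %sb, %flags)
           : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
    func.call @xsmm_intel_amx_tile_config_invoke(%dtype, %cfg, %state, %offset) : (i64, i64, memref<64xi8>, index) -> ()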
diff --git a/test/Passes/pass-convert-gemm-to-parallel-tile.mlir b/test/Passes/pass-convert-gemm-to-parallel-tile.mlir
index 8efc7ab18..e6f75a897 100644
--- a/test/Passes/pass-convert-gemm-to-parallel-tile.mlir
+++ b/test/Passes/pass-convert-gemm-to-parallel-tile.mlir
@@ -31,7 +31,7 @@ module {
 // CHECK: omp.wsloop for (%[[ARG3:.*]], %[[ARG4:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
 // CHECK: memref.alloca_scope {
 // CHECK: scf.for %[[ARG5:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: %[[temp1:.*]] = arith.addi %[[ARG5]], %[[ARG3]] : index
 // CHECK: scf.for %[[ARG6:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
-// CHECK: %[[temp1:.*]] = arith.addi %[[ARG5]], %[[ARG3]] : index
 // CHECK: %[[temp2:.*]] = arith.addi %[[ARG6]], %[[ARG4]] : index
diff --git a/test/Passes/pass-convert-mlp-to-parallel-tile.mlir b/test/Passes/pass-convert-mlp-to-parallel-tile.mlir
index c941c15cd..2227dc44f 100644
--- a/test/Passes/pass-convert-mlp-to-parallel-tile.mlir
+++ b/test/Passes/pass-convert-mlp-to-parallel-tile.mlir
@@ -82,22 +82,22 @@ module {
 //CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
 //CHECK: memref.alloca_scope {
 //CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
-//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
 //CHECK: omp.parallel {
 //CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
 //CHECK: memref.alloca_scope {
 //CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
-//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
 //CHECK: omp.parallel {
 //CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
 //CHECK: memref.alloca_scope {
 //CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
-//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
diff --git a/test/Passes/pass-tileconfig-hoisting-pass.mlir b/test/Passes/pass-tileconfig-hoisting-pass.mlir
new file mode 100644
index 000000000..5e4786476
--- /dev/null
+++ b/test/Passes/pass-tileconfig-hoisting-pass.mlir
@@ -0,0 +1,132 @@
+// RUN: tpp-opt %s --intel-amx-tile-config-hoisting-pass | FileCheck %s
+
+module {
+
+memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
+
+func.func @entry(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c32_i64 = arith.constant 32 : i64
+  %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+  %1 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+  %2 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  %3 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) {
+    scf.for %arg3 = %c0 to %c2 step %c1 {
+      %10 = arith.addi %arg3, %arg1 : index
+      %subview = memref.subview %arg0[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg4 = %c0 to %c8 step %c1 {
+        %11 = arith.addi %arg4, %arg2 : index
+        %subview_1 = memref.subview %alloc[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+        %alloca = memref.alloca() : memref<64xi8>
+        "xsmm.IntelAMXtileConfig"(%1, %alloca) : (i64, memref<64xi8>) -> ()
+        xsmm.brgemm(data_type = bf16, %3, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        "xsmm.IntelAMXtileConfig"(%2, %alloca) : (i64, memref<64xi8>) -> ()
+      }
+    }
+    scf.reduce
+  }
+  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+  %4 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+  %5 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  %6 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) {
+    scf.for %arg3 = %c0 to %c2 step %c1 {
+      %10 = arith.addi %arg3, %arg1 : index
+      %subview = memref.subview %alloc[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg4 = %c0 to %c8 step %c1 {
+        %11 = arith.addi %arg4, %arg2 : index
+        %subview_1 = memref.subview %alloc_0[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+        %alloca = memref.alloca() : memref<64xi8>
+        "xsmm.IntelAMXtileConfig"(%4, %alloca) : (i64, memref<64xi8>) -> ()
+        xsmm.brgemm(data_type = bf16, %6, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        "xsmm.IntelAMXtileConfig"(%5, %alloca) : (i64, memref<64xi8>) -> ()
+      }
+    }
+    scf.reduce
+  }
+  %7 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+  %8 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  %9 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) {
+    scf.for %arg3 = %c0 to %c2 step %c1 {
+      %10 = arith.addi %arg3, %arg1 : index
+      %subview = memref.subview %alloc_0[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg4 = %c0 to %c8 step %c1 {
+        %11 = arith.addi %arg4, %arg2 : index
+        %subview_1 = memref.subview %alloc[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+        %alloca = memref.alloca() : memref<64xi8>
+        "xsmm.IntelAMXtileConfig"(%7, %alloca) : (i64, memref<64xi8>) -> ()
+        xsmm.brgemm(data_type = bf16, %9, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        "xsmm.IntelAMXtileConfig"(%8, %alloca) : (i64, memref<64xi8>) -> ()
+      }
+    }
+    scf.reduce
+  }
+  memref.dealloc %alloc_0 : memref<8x32x32x32xbf16>
+  return %alloc : memref<8x32x32x32xbf16>
+}
+}
+
+// CHECK-LABEL: func.func @entry(
+// CHECK: %[[ARG0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64
+// CHECK: %[[dispatch1:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[dispatch2:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[brgemmdispatch:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch1]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 {
+// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: xsmm.brgemm(data_type = bf16, %[[brgemmdispatch]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: }
+// CHECK: }
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch2]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: %[[dispatch3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[dispatch4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[brgemmdispatch2:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch3]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 {
+// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: xsmm.brgemm(data_type = bf16, %[[brgemmdispatch2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: }
+// CHECK: }
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: %[[dispatch5:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[dispatch6:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[brgemmdispatch3:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch5]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 {
+// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: xsmm.brgemm(data_type = bf16, %[[brgemmdispatch3]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: }
+// CHECK: }
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch6]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.reduce
+// CHECK: }
diff --git a/test/Passes/pass-tileconfig-insertion.mlir b/test/Passes/pass-tileconfig-insertion.mlir
new file mode 100644
index 000000000..5ccfdd87d
--- /dev/null
+++ b/test/Passes/pass-tileconfig-insertion.mlir
@@ -0,0 +1,113 @@
+// RUN: tpp-opt %s --intel-amx-tile-config-insertion-pass | FileCheck %s
+
+module {
+  memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @entry(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %c8 = arith.constant 8 : index
+    %c0 = arith.constant 0 : index
+    %c32_i64 = arith.constant 32 : i64
+    %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+    %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
+    %c0_0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %c8_1 = arith.constant 8 : index
+    %2 = arith.muli %c1, %c2 : index
+    %3 = arith.muli %c1, %c8_1 : index
+    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%2, %3) {
+      scf.for %arg3 = %c0_0 to %2 step %c1 {
+        scf.for %arg4 = %c0_0 to %3 step %c1 {
+          %8 = arith.addi %arg3, %arg1 : index
+          %9 = arith.addi %arg4, %arg2 : index
+          %subview = memref.subview %alloc[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+          %subview_9 = memref.subview %arg0[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+          xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        }
+      }
+      scf.reduce
+    }
+    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+    %c0_3 = arith.constant 0 : index
+    %c2_4 = arith.constant 2 : index
+    %c8_5 = arith.constant 8 : index
+    %4 = arith.muli %c1, %c2_4 : index
+    %5 = arith.muli %c1, %c8_5 : index
+    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%4, %5) {
+      scf.for %arg3 = %c0_3 to %4 step %c1 {
+        scf.for %arg4 = %c0_3 to %5 step %c1 {
+          %8 = arith.addi %arg3, %arg1 : index
+          %9 = arith.addi %arg4, %arg2 : index
+          %subview = memref.subview %alloc_2[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+          %subview_9 = memref.subview %alloc[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+          xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        }
+      }
+      scf.reduce
+    }
+    %c0_6 = arith.constant 0 : index
+    %c2_7 = arith.constant 2 : index
+    %c8_8 = arith.constant 8 : index
+    %6 = arith.muli %c1, %c2_7 : index
+    %7 = arith.muli %c1, %c8_8 : index
+    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%6, %7) {
+      scf.for %arg3 = %c0_6 to %6 step %c1 {
+        scf.for %arg4 = %c0_6 to %7 step %c1 {
+          %8 = arith.addi %arg3, %arg1 : index
+          %9 = arith.addi %arg4, %arg2 : index
+          %subview = memref.subview %alloc[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+          %subview_9 = memref.subview %alloc_2[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+          xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        }
+      }
+      scf.reduce
+    }
+    memref.dealloc %alloc_2 : memref<8x32x32x32xbf16>
+    return %alloc : memref<8x32x32x32xbf16>
+  }
+}
+
+// CHECK: func.func @entry(%[[ARG0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
+// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
+// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
+// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
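
The two new tests exercise the passes in isolation; in LocalDialectsLowering they run back to back, with canonicalization and LICM in between. A combined invocation would look like this (hypothetical RUN line, not part of this patch):

    // RUN: tpp-opt %s --intel-amx-tile-config-insertion-pass --canonicalize \
    // RUN:   --loop-invariant-code-motion --intel-amx-tile-config-hoisting-pass | FileCheck %s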