From b1592587c40a8cc47b16f9b4f06f8f8a8992c0b3 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Tue, 20 Feb 2024 05:03:11 -0800 Subject: [PATCH 01/10] Tile configuration addition pass --- include/TPP/Dialect/Xsmm/XsmmEnum.td | 6 +- include/TPP/Dialect/Xsmm/XsmmOps.td | 19 +++ include/TPP/Passes.td | 11 ++ .../ConvertXsmmToFunc/ConvertXsmmToFunc.cpp | 34 +++- lib/TPP/DefaultPipeline.cpp | 1 + lib/TPP/Dialect/Xsmm/XsmmOps.cpp | 15 ++ lib/TPP/Transforms/CMakeLists.txt | 1 + lib/TPP/Transforms/TileConfig.cpp | 154 ++++++++++++++++++ runtime/Xsmm/XsmmRunnerUtils.cpp | 58 ++++++- runtime/Xsmm/XsmmRunnerUtils.h | 8 + 10 files changed, 299 insertions(+), 8 deletions(-) create mode 100644 lib/TPP/Transforms/TileConfig.cpp diff --git a/include/TPP/Dialect/Xsmm/XsmmEnum.td b/include/TPP/Dialect/Xsmm/XsmmEnum.td index a6d67c389..71a13e425 100644 --- a/include/TPP/Dialect/Xsmm/XsmmEnum.td +++ b/include/TPP/Dialect/Xsmm/XsmmEnum.td @@ -76,7 +76,9 @@ def Xsmm_GemmFlags : I64EnumAttr< I64EnumAttrCase<"BETA_0", 4, "beta_0">, I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">, I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">, - I64EnumAttrCase<"VNNI_C", 8192, "vnni_c"> - ]> { + I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">, + I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">, + I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig"> + ]> { let cppNamespace = "mlir::xsmm"; } diff --git a/include/TPP/Dialect/Xsmm/XsmmOps.td b/include/TPP/Dialect/Xsmm/XsmmOps.td index b9cfad133..77efa27c5 100644 --- a/include/TPP/Dialect/Xsmm/XsmmOps.td +++ b/include/TPP/Dialect/Xsmm/XsmmOps.td @@ -263,6 +263,15 @@ def Xsmm_GemmDispatchOp : Xsmm_GemmLikeOp<"gemm.dispatch"> { def Xsmm_BrgemmDispatchOp : Xsmm_GemmLikeOp<"brgemm.dispatch"> { let summary = "dispatch for brgemm operation."; let hasVerifier = 1; + +} + +//===----------------------------------------------------------------------===// +// TileConfigDispatchOp +//===----------------------------------------------------------------------===// + +def Xsmm_TileConfigDispatchOp : Xsmm_GemmLikeOp<"tileConfig.dispatch"> { + let summary = "dispatch for tileConfig operation."; } //===----------------------------------------------------------------------===// @@ -298,4 +307,14 @@ def Xsmm_FusedBrgemmDispatchOp : Xsmm_Op<"fused_brgemm.dispatch", [Pure]> { let hasVerifier = 1; } + +//===----------------------------------------------------------------------===// +// TileConfigOp +//===----------------------------------------------------------------------===// + +def Xsmm_TileConfigOp : Xsmm_Op<"tileConfig", [MemoryEffects<[MemWrite, MemRead]>]> { + let summary = "invoke for tileConfig operation."; + let arguments = (ins I64:$dispatch, Variadic:$inputs); +} + #endif // TPP_XSMM_OPS diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index dc51cfb08..9ad22ddfc 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -470,6 +470,7 @@ def FoldXsmmFlags : Pass<"fold-xsmm-flags", "func::FuncOp"> { let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ]; } + def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> { let summary = "Tile parallel loops"; let options = [ @@ -494,4 +495,14 @@ def GpuInlineConstants : Pass<"gpu-inline-constants", "func::FuncOp"> { "arith::ArithDialect"]; } +def TileConfigInsertionPass : Pass<"tile-config-insertion-pass", + "func::FuncOp"> { + let summary = "Insert tile configuration xsmm calls"; + let description = [{ + Insert tile configuration xsmm calls and perform LICM on them. 
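+
+    As a rough sketch (SSA names illustrative, dispatch operands elided), a
+    bf16 xsmm.brgemm invocation is rewritten into a setup/compute/reset
+    sequence that shares a 64-byte tile-state buffer:
+
+      %setup = xsmm.tileConfig.dispatch [...] flags = (no_reset_tileconfig, ...) data_type = bf16
+      %reset = xsmm.tileConfig.dispatch [...] flags = (no_setup_tileconfig, ...) data_type = bf16
+      %state = memref.alloca() : memref<64xi8>
+      "xsmm.tileConfig"(%setup, %state) : (i64, memref<64xi8>) -> ()
+      xsmm.brgemm(data_type = bf16, ...)  // re-dispatched with both no_*_tileconfig flags
+      "xsmm.tileConfig"(%reset, %state) : (i64, memref<64xi8>) -> ()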
+ }]; + + let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ]; +} + #endif // TPP_DIALECT_TPP_PASSES diff --git a/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp b/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp index 737ee27e1..292534964 100644 --- a/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp +++ b/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp @@ -171,6 +171,19 @@ struct ConvertFusedBrgemmXsmmOp : public OpRewritePattern { } }; +struct ConvertTileConfigXsmmOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TileConfigOp tileConfigOp, + PatternRewriter &rewriter) const override { + std::string funcName = "xsmm_tile_config_invoke"; + buildInvokeCall(rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp, + xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)); + rewriter.eraseOp(tileConfigOp); + return success(); + } +}; + static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, ArrayRef dispatchOperands, ArrayRef dispatchOperandTypes, @@ -227,6 +240,13 @@ void addKindOperand(RewriterBase &rewriter, FusedBrgemmDispatchOp dispatchOp, /* do nothing */ } +void addKindOperand(RewriterBase &rewriter, TileConfigDispatchOp dispatchOp, + SmallVectorImpl &dispatchOperands, + SmallVectorImpl &dispatchOperandTypes) { + /* do nothing */ +} + + static int64_t getOredFlags(ArrayAttr flags) { int64_t oredFlag = 0; for (auto flag : flags) { @@ -370,6 +390,16 @@ struct ConvertUnaryDispatchOp : public OpRewritePattern { } }; +struct ConvertTileConfigDispatchOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TileConfigDispatchOp dispatchOp, + PatternRewriter &rewriter) const override { + return buildDispatchOp(rewriter, dispatchOp, + "xsmm_tile_config_dispatch"); + } +}; + struct ConvertFusedBrgemmOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -395,11 +425,11 @@ struct ConvertXsmmToFunc RewritePatternSet patterns(&getContext()); patterns .add( + ConvertGemmXsmmOp, ConvertBrgemmXsmmOp, ConvertFusedBrgemmXsmmOp, ConvertTileConfigXsmmOp>( patterns.getContext()); patterns.add( + ConvertBrgemmDispatchOp, ConvertFusedBrgemmOp, ConvertTileConfigDispatchOp>( patterns.getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp index 65628778b..4aff3fc62 100644 --- a/lib/TPP/DefaultPipeline.cpp +++ b/lib/TPP/DefaultPipeline.cpp @@ -144,6 +144,7 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, mlir::tpp::SCFParallelLoopTilingOptions tilingOptions; tilingOptions.tileSizes = parallelTaskGrid; pm.addPass(createSCFParallelLoopTiling(tilingOptions)); + pm.addNestedPass(createTileConfigInsertionPass()); pm.addPass(createConvertSCFToOpenMPPass()); } pm.addPass(createConvertVectorToSCFPass()); diff --git a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp index 1b527743a..06394082a 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp @@ -235,6 +235,21 @@ void BinaryDispatchOp::print(OpAsmPrinter &printer) { printerDataTypeImpl(printer, *this); } +void TileConfigDispatchOp::print(OpAsmPrinter &printer) { + printerInputImpl(printer, *this); + auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; + printerFlagsImpl(printer, getOpFlags, FLAGS_NAME); + printerDataTypeImpl(printer, *this); +} + 
+ParseResult TileConfigDispatchOp::parse(OpAsmParser &parser, + OperationState &result) { + if (failed(parseInputImpl(parser, result)) || + failed(parserFlagsImpl(parser, result, FLAGS_NAME))) + return failure(); + return parseDataTypeImpl(parser, result); +} + template static LogicalResult verifyUniquenessAndConsistency(ArrayAttr flags, Operation *op, diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index f3644e3ec..8b93e6041 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_library(TPPTransforms TransformUtils.cpp CombineXsmmPass.cpp SCFParallelLoopTiling.cpp + TileConfig.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP diff --git a/lib/TPP/Transforms/TileConfig.cpp b/lib/TPP/Transforms/TileConfig.cpp new file mode 100644 index 000000000..f94e34dc6 --- /dev/null +++ b/lib/TPP/Transforms/TileConfig.cpp @@ -0,0 +1,154 @@ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop tiling on parallel loops. +// +//===----------------------------------------------------------------------===// +#include "TPP/Dialect/Xsmm/XsmmOps.h" +#include "TPP/Dialect/Xsmm/XsmmUtils.h" +#include "TPP/Transforms/Utils/VNNIUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include +namespace mlir { +namespace tpp { +#define GEN_PASS_DEF_TILECONFIGINSERTIONPASS +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +using namespace mlir; + +namespace mlir { +namespace tpp { + +template +static void appendBrgemmFlags(SmallVector &attributes, + PatternRewriter &rewriter, InvokeOpTy opTy) { + auto flags = + dyn_cast(opTy.getOperand(0).getDefiningOp()).getFlags(); + for (auto flagItr : flags) { + if (flagItr == xsmm::GemmFlagsAttr::get(rewriter.getContext(), + xsmm::GemmFlags::NONE)) { + return; + } + attributes.push_back(flagItr); + } + + if (attributes.empty()) + attributes.push_back( + xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::NONE)); +} + +template +struct TileConfig : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InvokeOpTy op, + PatternRewriter &rewriter) const override { + if (xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()) != + xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)) + return failure(); + auto flags = + dyn_cast(op.getOperand(0).getDefiningOp()).getFlags(); + for (auto flagItr : flags) { + if (flagItr == xsmm::GemmFlagsAttr::get( + rewriter.getContext(), + mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG) || + flagItr == xsmm::GemmFlagsAttr::get( + rewriter.getContext(), + mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) { + return failure(); + } + } + + SmallVector attributesSetup; + attributesSetup.push_back(xsmm::GemmFlagsAttr::get( + rewriter.getContext(), xsmm::GemmFlags::NO_RESET_TILECONFIG)); + appendBrgemmFlags(attributesSetup, rewriter, op); + auto tileConfigSetup = rewriter.create( + op.getLoc(), rewriter.getI64Type(), + DenseI64ArrayAttr::get( + rewriter.getContext(), + 
dyn_cast(op.getOperand(0).getDefiningOp()) + .getInputs()), + rewriter.getArrayAttr(attributesSetup), + xsmm::utils::getDataType(rewriter, op.getOperand(1).getType())); + + SmallVector attributesReset; + attributesReset.push_back(xsmm::GemmFlagsAttr::get( + rewriter.getContext(), xsmm::GemmFlags::NO_SETUP_TILECONFIG)); + appendBrgemmFlags(attributesReset, rewriter, op); + auto tileConfigReset = rewriter.create( + op.getLoc(), rewriter.getI64Type(), + DenseI64ArrayAttr::get( + rewriter.getContext(), + dyn_cast(op.getOperand(0).getDefiningOp()) + .getInputs()), + rewriter.getArrayAttr(attributesReset), + xsmm::utils::getDataType(rewriter, op.getOperand(1).getType())); + + SmallVector attributesBrgemm; + attributesBrgemm.push_back(xsmm::GemmFlagsAttr::get( + rewriter.getContext(), xsmm::GemmFlags::NO_RESET_TILECONFIG)); + attributesBrgemm.push_back(xsmm::GemmFlagsAttr::get( + rewriter.getContext(), xsmm::GemmFlags::NO_SETUP_TILECONFIG)); + appendBrgemmFlags(attributesBrgemm, rewriter, op); + + auto dispatch = dyn_cast( + rewriter.clone(*op.getOperand(0).getDefiningOp())); + dispatch.setFlagsAttr(rewriter.getArrayAttr(attributesBrgemm)); + + auto alloca = rewriter.create( + op.getLoc(), MemRefType::get({64}, rewriter.getI8Type())); + + ValueRange tileConfigInputs{alloca}; + rewriter.create( + op.getLoc(), tileConfigSetup, tileConfigInputs); + + SmallVector invokeOperands; + invokeOperands.push_back(dispatch); + auto opItr = op->getOperands().begin(); + std::advance(opItr, 1); + invokeOperands.append(opItr, op->getOperands().end()); + rewriter.create( + op.getLoc(), + xsmm::utils::getDataType(rewriter, op.getOperand(1).getType()), + invokeOperands); + + ValueRange tileResetInputs{alloca}; + rewriter.create( + op.getLoc(), tileConfigReset, tileResetInputs); + + //rewriter.create(op.getLoc(), alloca); + rewriter.eraseOp(op); + rewriter.eraseOp(op.getOperand(0).getDefiningOp()); + return success(); + } +}; + +struct TileConfigInsertionPass + : public impl::TileConfigInsertionPassBase { + void populateCombinePatterns(RewritePatternSet &patterns) { + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateCombinePatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; +} // namespace tpp +} // namespace mlir diff --git a/runtime/Xsmm/XsmmRunnerUtils.cpp b/runtime/Xsmm/XsmmRunnerUtils.cpp index b1ba23eb4..271c45396 100644 --- a/runtime/Xsmm/XsmmRunnerUtils.cpp +++ b/runtime/Xsmm/XsmmRunnerUtils.cpp @@ -14,6 +14,7 @@ #include "XsmmRunnerUtils.h" #include "libxsmm.h" // NOLINT [build/include_subdir] +#include "libxsmm_utils.h" // Helper function prototypes. 
static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape, @@ -209,6 +210,43 @@ xsmm_binary_dispatch(const libxsmm_meltw_binary_type op_type, return reinterpret_cast(kernel); } +extern "C" int64_t xsmm_tile_config_dispatch(const libxsmm_datatype dtype, + int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, + int64_t ldc, int64_t stride_a, + int64_t stride_b, + const libxsmm_gemm_flags flags) { + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_cfg_flags = flags; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb; + l_shape.ldb = lda; + l_shape.ldc = ldc; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = dtype; + l_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + + auto sgemm = libxsmm_dispatch_tilecfg_gemm(l_shape, l_cfg_flags); + if (!sgemm) { + fprintf(stderr, "failed to generate tileconfig func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + fprintf(stderr, "flags: %u\n", flags); + printXsmmStruct(l_shape); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + extern "C" void xsmm_unary_invoke(const libxsmm_datatype dType, int64_t addr, void *alignedPtrIn, int64_t offsetIn, void *alignedPtrOut, int64_t offsetOut) { @@ -311,8 +349,8 @@ extern "C" int64_t xsmm_brgemm_dispatch(const libxsmm_datatype dtype, int64_t m, l_brconfig.br_stride_b_hint = stride_a * typeSize; l_brconfig.br_unroll_hint = 0; - auto sgemm = libxsmm_dispatch_brgemm(l_shape, l_flags, l_prefetch_flags, - l_brconfig); + auto sgemm = + libxsmm_dispatch_brgemm(l_shape, l_flags, l_prefetch_flags, l_brconfig); if (!sgemm) { fprintf(stderr, "failed to generate brgemm func\n"); fprintf(stderr, "dtype: %u\n", dtype); @@ -407,8 +445,8 @@ xsmm_fused_brgemm_dispatch(const libxsmm_datatype data_type, int64_t m, l_postops.d_binary_type = binary_op_type; l_postops.ldd = ldc; - auto sgemm = libxsmm_dispatch_brgemm_ext( - l_shape, l_flags, l_prefetch_flags, l_brconfig, l_argops, l_postops); + auto sgemm = libxsmm_dispatch_brgemm_ext(l_shape, l_flags, l_prefetch_flags, + l_brconfig, l_argops, l_postops); if (!sgemm) { fprintf(stderr, "failed to generate fused brgemm func\n"); fprintf(stderr, "data_type: %u\n", data_type); @@ -420,6 +458,18 @@ xsmm_fused_brgemm_dispatch(const libxsmm_datatype data_type, int64_t m, return reinterpret_cast(sgemm); } +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *tileState, int64_t offset) { + libxsmm_xmmfunction cfg_tr; + + libxsmm_tilecfg_state *l_tilestate = + reinterpret_cast(tileState); + + cfg_tr.tilecfg = reinterpret_cast(addr); + cfg_tr.tilecfg(l_tilestate); +} + static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape, FILE *outfile) { fprintf(outfile, "M: %d\n", gemmShape.m); diff --git a/runtime/Xsmm/XsmmRunnerUtils.h b/runtime/Xsmm/XsmmRunnerUtils.h index 329b0a7c6..6ca412553 100644 --- a/runtime/Xsmm/XsmmRunnerUtils.h +++ b/runtime/Xsmm/XsmmRunnerUtils.h @@ -44,6 +44,10 @@ extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_fused_brgemm_dispatch( const libxsmm_meltw_binary_flags binary_flags, const libxsmm_meltw_binary_type binary_op_type); +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_tile_config_dispatch( + const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_gemm_flags); + extern "C" MLIR_RUNNERUTILS_EXPORT void xsmm_gemm_invoke(const libxsmm_datatype dType, 
int64_t addr, void *alignedPtrA, int64_t offsetA, void *alignedPtrB, int64_t offsetB, @@ -74,4 +78,8 @@ extern "C" MLIR_RUNNERUTILS_EXPORT void xsmm_fused_brgemm_invoke( int64_t offsetA, void *alignedPtrB, int64_t offsetB, void *alignedPtrC, int64_t offsetC, void *alignedPtrD, int64_t offsetD, int64_t numBatches); +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offset); + #endif // TPP_EXECUTIONENGINE_CRUNNERUTILS_H From abab8e89f1b0dcdc1e20894d1290369feabfb8d9 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Tue, 27 Feb 2024 06:44:00 -0800 Subject: [PATCH 02/10] Tile config hoisting pass --- include/TPP/Passes.td | 12 ++- lib/TPP/Transforms/CMakeLists.txt | 1 + lib/TPP/Transforms/TileConfig.cpp | 19 ++-- lib/TPP/Transforms/TileConfigHoisting.cpp | 101 ++++++++++++++++++++++ 4 files changed, 122 insertions(+), 11 deletions(-) create mode 100644 lib/TPP/Transforms/TileConfigHoisting.cpp diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 9ad22ddfc..dab4f9a40 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -499,7 +499,17 @@ def TileConfigInsertionPass : Pass<"tile-config-insertion-pass", "func::FuncOp"> { let summary = "Insert tile configuration xsmm calls"; let description = [{ - Insert tile configuration xsmm calls and perform LICM on them. + Insert tile configuration xsmm calls. + }]; + + let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ]; +} + +def TileConfigHoistingPass : Pass<"tile-config-hoisting-pass", + "func::FuncOp"> { + let summary = "Hoist tile configuration invoke xsmm calls"; + let description = [{ + Run LICM on Tile configuration invoke calls. }]; let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ]; diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index 8b93e6041..3fbfde76a 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -18,6 +18,7 @@ add_mlir_library(TPPTransforms CombineXsmmPass.cpp SCFParallelLoopTiling.cpp TileConfig.cpp + TileConfigHoisting.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP diff --git a/lib/TPP/Transforms/TileConfig.cpp b/lib/TPP/Transforms/TileConfig.cpp index f94e34dc6..e314cb73c 100644 --- a/lib/TPP/Transforms/TileConfig.cpp +++ b/lib/TPP/Transforms/TileConfig.cpp @@ -1,3 +1,4 @@ +//===- TileConfig.cpp -----------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -5,19 +6,17 @@ // //===----------------------------------------------------------------------===// // -// This file implements loop tiling on parallel loops. +// This file inserts tile configuration calls. 
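+// Hoisting of the inserted calls out of loops is handled separately by the
+// TileConfigHoisting pass (TileConfigHoisting.cpp).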
// //===----------------------------------------------------------------------===// #include "TPP/Dialect/Xsmm/XsmmOps.h" #include "TPP/Dialect/Xsmm/XsmmUtils.h" -#include "TPP/Transforms/Utils/VNNIUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include namespace mlir { namespace tpp { #define GEN_PASS_DEF_TILECONFIGINSERTIONPASS @@ -109,11 +108,11 @@ struct TileConfig : OpRewritePattern { auto alloca = rewriter.create( op.getLoc(), MemRefType::get({64}, rewriter.getI8Type())); - + ValueRange tileConfigInputs{alloca}; - rewriter.create( - op.getLoc(), tileConfigSetup, tileConfigInputs); - + rewriter.create(op.getLoc(), tileConfigSetup, + tileConfigInputs); + SmallVector invokeOperands; invokeOperands.push_back(dispatch); auto opItr = op->getOperands().begin(); @@ -125,10 +124,10 @@ struct TileConfig : OpRewritePattern { invokeOperands); ValueRange tileResetInputs{alloca}; - rewriter.create( - op.getLoc(), tileConfigReset, tileResetInputs); + rewriter.create(op.getLoc(), tileConfigReset, + tileResetInputs); - //rewriter.create(op.getLoc(), alloca); + // rewriter.create(op.getLoc(), alloca); rewriter.eraseOp(op); rewriter.eraseOp(op.getOperand(0).getDefiningOp()); return success(); diff --git a/lib/TPP/Transforms/TileConfigHoisting.cpp b/lib/TPP/Transforms/TileConfigHoisting.cpp new file mode 100644 index 000000000..1bf1deb50 --- /dev/null +++ b/lib/TPP/Transforms/TileConfigHoisting.cpp @@ -0,0 +1,101 @@ +//===- TileConfigHoisting.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements tile configuration hoisting on parallel loops. 
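+// The tile-state alloca and its paired tileConfig setup/reset invokes are
+// moved to the beginning and end of the body of the closest enclosing
+// scf.parallel, so the configuration runs once per parallel iteration.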
+// +//===----------------------------------------------------------------------===// +#include "TPP/Dialect/Xsmm/XsmmOps.h" +#include "TPP/Dialect/Xsmm/XsmmUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +namespace mlir { +namespace tpp { +#define GEN_PASS_DEF_TILECONFIGHOISTINGPASS +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +using namespace mlir; +using namespace mlir::xsmm; + +namespace mlir { +namespace tpp { + +struct TileConfigHoisting : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::AllocaOp alloca, + PatternRewriter &rewriter) const override { + + xsmm::TileConfigOp firstTileConfig, secondTileConfig; + for (auto *user : alloca->getUsers()) { + if (!dyn_cast(user)) { + return failure(); + } + auto flags = + dyn_cast( + dyn_cast(user).getOperand(0).getDefiningOp()) + .getFlags(); + for (auto flagItr : flags) { + if (flagItr == xsmm::GemmFlagsAttr::get( + rewriter.getContext(), + mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG)) { + firstTileConfig = dyn_cast(user); + + } else if (flagItr == xsmm::GemmFlagsAttr::get( + rewriter.getContext(), + mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) { + secondTileConfig = dyn_cast(user); + } + } + } + + scf::ParallelOp parallelOpParent = NULL; + auto op = alloca.getOperation(); + while (true) { + if (op->getParentOfType()) { + if (&op->getParentOfType().getRegion() == + alloca->getParentRegion()) { + return failure(); + } + parallelOpParent = op->getParentOfType(); + break; + } + op = op->getParentOp(); + } + + if (parallelOpParent == NULL) + return failure(); + + rewriter.moveOpBefore(alloca, parallelOpParent.getBody(), + parallelOpParent.getBody()->begin()); + rewriter.moveOpAfter(firstTileConfig, alloca); + rewriter.moveOpBefore(secondTileConfig, parallelOpParent.getBody(), + std::prev(parallelOpParent.getBody()->end(), 1)); + return success(); + } +}; + +struct TileConfigHoistingPass + : public impl::TileConfigHoistingPassBase { + void populateCombinePatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateCombinePatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; +} // namespace tpp +} // namespace mlir From d077e7ddf27ee206d8e03e6d4cbe532426b49a37 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Wed, 28 Feb 2024 00:07:28 -0800 Subject: [PATCH 03/10] Tile config changed to intel amx tile config --- include/TPP/Dialect/Xsmm/XsmmOps.td | 12 ++--- include/TPP/Passes.td | 12 ++--- .../ConvertXsmmToFunc/ConvertXsmmToFunc.cpp | 49 ++++++++++--------- lib/TPP/Dialect/Xsmm/XsmmOps.cpp | 12 ++--- lib/TPP/Transforms/CMakeLists.txt | 4 +- ...{TileConfig.cpp => IntelAMXTileConfig.cpp} | 40 +++++++-------- ...ing.cpp => IntelAMXTileConfigHoisting.cpp} | 32 ++++++------ runtime/Xsmm/XsmmRunnerUtils.cpp | 14 +++--- runtime/Xsmm/XsmmRunnerUtils.h | 6 +-- 9 files changed, 89 insertions(+), 92 deletions(-) rename lib/TPP/Transforms/{TileConfig.cpp => IntelAMXTileConfig.cpp} (82%) rename lib/TPP/Transforms/{TileConfigHoisting.cpp => IntelAMXTileConfigHoisting.cpp} (75%) diff --git a/include/TPP/Dialect/Xsmm/XsmmOps.td b/include/TPP/Dialect/Xsmm/XsmmOps.td index 77efa27c5..0edb2da4c 100644 
--- a/include/TPP/Dialect/Xsmm/XsmmOps.td +++ b/include/TPP/Dialect/Xsmm/XsmmOps.td @@ -267,11 +267,11 @@ def Xsmm_BrgemmDispatchOp : Xsmm_GemmLikeOp<"brgemm.dispatch"> { } //===----------------------------------------------------------------------===// -// TileConfigDispatchOp +// IntelAMXTileConfigDispatchOp //===----------------------------------------------------------------------===// -def Xsmm_TileConfigDispatchOp : Xsmm_GemmLikeOp<"tileConfig.dispatch"> { - let summary = "dispatch for tileConfig operation."; +def Xsmm_IntelAMXTileConfigDispatchOp : Xsmm_GemmLikeOp<"IntelAMXtileConfig.dispatch"> { + let summary = "dispatch for Intel amx tileConfig operation."; } //===----------------------------------------------------------------------===// @@ -309,11 +309,11 @@ def Xsmm_FusedBrgemmDispatchOp : Xsmm_Op<"fused_brgemm.dispatch", [Pure]> { //===----------------------------------------------------------------------===// -// TileConfigOp +// IntelAMXTileConfigOp //===----------------------------------------------------------------------===// -def Xsmm_TileConfigOp : Xsmm_Op<"tileConfig", [MemoryEffects<[MemWrite, MemRead]>]> { - let summary = "invoke for tileConfig operation."; +def Xsmm_IntelAMXTileConfigOp : Xsmm_Op<"IntelAMXtileConfig", [MemoryEffects<[MemWrite, MemRead]>]> { + let summary = "invoke for Intel AMX tileConfig operation."; let arguments = (ins I64:$dispatch, Variadic:$inputs); } diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index dab4f9a40..f5fb77131 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -495,21 +495,21 @@ def GpuInlineConstants : Pass<"gpu-inline-constants", "func::FuncOp"> { "arith::ArithDialect"]; } -def TileConfigInsertionPass : Pass<"tile-config-insertion-pass", +def IntelAMXTileConfigInsertionPass : Pass<"intel-amx-tile-config-insertion-pass", "func::FuncOp"> { - let summary = "Insert tile configuration xsmm calls"; + let summary = "Insert intel amx tile configuration xsmm calls"; let description = [{ - Insert tile configuration xsmm calls. + Insert intel amx tile configuration xsmm calls. }]; let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ]; } -def TileConfigHoistingPass : Pass<"tile-config-hoisting-pass", +def IntelAMXTileConfigHoistingPass : Pass<"intel-amx-tile-config-hoisting-pass", "func::FuncOp"> { - let summary = "Hoist tile configuration invoke xsmm calls"; + let summary = "Hoist intel amx tile configuration invoke xsmm calls"; let description = [{ - Run LICM on Tile configuration invoke calls. + Run LICM on intel amx tile configuration invoke calls. 
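+    A rough sketch of the resulting structure (SSA names illustrative):
+
+      scf.parallel (...) {
+        %state = memref.alloca() : memref<64xi8>
+        "xsmm.IntelAMXtileConfig"(%setup, %state) : (i64, memref<64xi8>) -> ()
+        scf.for ... {
+          xsmm.brgemm(data_type = bf16, ...)
+        }
+        "xsmm.IntelAMXtileConfig"(%reset, %state) : (i64, memref<64xi8>) -> ()
+        scf.reduce
+      }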
}]; let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ]; diff --git a/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp b/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp index 292534964..84183c3f8 100644 --- a/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp +++ b/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp @@ -171,14 +171,16 @@ struct ConvertFusedBrgemmXsmmOp : public OpRewritePattern { } }; -struct ConvertTileConfigXsmmOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConvertIntelAMXTileConfigXsmmOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TileConfigOp tileConfigOp, + LogicalResult matchAndRewrite(IntelAMXTileConfigOp tileConfigOp, PatternRewriter &rewriter) const override { - std::string funcName = "xsmm_tile_config_invoke"; - buildInvokeCall(rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp, - xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)); + std::string funcName = "xsmm_intel_amx_tile_config_invoke"; + buildInvokeCall( + rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp, + xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)); rewriter.eraseOp(tileConfigOp); return success(); } @@ -208,10 +210,9 @@ static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, return call; } -template ::value || - std::is_same::value>> +template ::value || + std::is_same::value>> void addKindOperand(RewriterBase &rewriter, OpTy dispatchOp, SmallVectorImpl &dispatchOperands, SmallVectorImpl &dispatchOperandTypes) { @@ -240,13 +241,13 @@ void addKindOperand(RewriterBase &rewriter, FusedBrgemmDispatchOp dispatchOp, /* do nothing */ } -void addKindOperand(RewriterBase &rewriter, TileConfigDispatchOp dispatchOp, +void addKindOperand(RewriterBase &rewriter, + IntelAMXTileConfigDispatchOp dispatchOp, SmallVectorImpl &dispatchOperands, SmallVectorImpl &dispatchOperandTypes) { /* do nothing */ } - static int64_t getOredFlags(ArrayAttr flags) { int64_t oredFlag = 0; for (auto flag : flags) { @@ -390,13 +391,14 @@ struct ConvertUnaryDispatchOp : public OpRewritePattern { } }; -struct ConvertTileConfigDispatchOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConvertIntelAMXTileConfigDispatchOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TileConfigDispatchOp dispatchOp, + LogicalResult matchAndRewrite(IntelAMXTileConfigDispatchOp dispatchOp, PatternRewriter &rewriter) const override { - return buildDispatchOp(rewriter, dispatchOp, - "xsmm_tile_config_dispatch"); + return buildDispatchOp( + rewriter, dispatchOp, "xsmm_intel_amx_tile_config_dispatch"); } }; @@ -423,13 +425,12 @@ struct ConvertXsmmToFunc : public tpp::impl::ConvertXsmmToFuncBase { void runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns - .add( - patterns.getContext()); - patterns.add( + patterns.add(patterns.getContext()); + patterns.add( patterns.getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp index 06394082a..139c13219 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp @@ -179,7 +179,7 @@ static void printerDataTypeImpl(OpAsmPrinter &printer, OpTy op) { template static void printerFlagsImpl(OpAsmPrinter &printer, - const std::function& fn, + const 
std::function &fn, const std::string_view &flagsName) { printer << " " << flagsName << " = ("; llvm::interleaveComma(fn(), printer, [&](auto &flag) { @@ -235,15 +235,15 @@ void BinaryDispatchOp::print(OpAsmPrinter &printer) { printerDataTypeImpl(printer, *this); } -void TileConfigDispatchOp::print(OpAsmPrinter &printer) { - printerInputImpl(printer, *this); +void IntelAMXTileConfigDispatchOp::print(OpAsmPrinter &printer) { + printerInputImpl(printer, *this); auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; printerFlagsImpl(printer, getOpFlags, FLAGS_NAME); - printerDataTypeImpl(printer, *this); + printerDataTypeImpl(printer, *this); } -ParseResult TileConfigDispatchOp::parse(OpAsmParser &parser, - OperationState &result) { +ParseResult IntelAMXTileConfigDispatchOp::parse(OpAsmParser &parser, + OperationState &result) { if (failed(parseInputImpl(parser, result)) || failed(parserFlagsImpl(parser, result, FLAGS_NAME))) return failure(); diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index 3fbfde76a..53d60a390 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -17,8 +17,8 @@ add_mlir_library(TPPTransforms TransformUtils.cpp CombineXsmmPass.cpp SCFParallelLoopTiling.cpp - TileConfig.cpp - TileConfigHoisting.cpp + IntelAMXTileConfig.cpp + IntelAMXTileConfigHoisting.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP diff --git a/lib/TPP/Transforms/TileConfig.cpp b/lib/TPP/Transforms/IntelAMXTileConfig.cpp similarity index 82% rename from lib/TPP/Transforms/TileConfig.cpp rename to lib/TPP/Transforms/IntelAMXTileConfig.cpp index e314cb73c..3a3c19bb3 100644 --- a/lib/TPP/Transforms/TileConfig.cpp +++ b/lib/TPP/Transforms/IntelAMXTileConfig.cpp @@ -1,4 +1,4 @@ -//===- TileConfig.cpp -----------------------------------------------------===// +//===- IntelAMXTileConfig.cpp ---------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -19,7 +19,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace tpp { -#define GEN_PASS_DEF_TILECONFIGINSERTIONPASS +#define GEN_PASS_DEF_INTELAMXTILECONFIGINSERTIONPASS #include "TPP/Passes.h.inc" } // namespace tpp } // namespace mlir @@ -35,10 +35,9 @@ static void appendBrgemmFlags(SmallVector &attributes, auto flags = dyn_cast(opTy.getOperand(0).getDefiningOp()).getFlags(); for (auto flagItr : flags) { - if (flagItr == xsmm::GemmFlagsAttr::get(rewriter.getContext(), - xsmm::GemmFlags::NONE)) { + if (flagItr == + xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::NONE)) return; - } attributes.push_back(flagItr); } @@ -48,7 +47,7 @@ static void appendBrgemmFlags(SmallVector &attributes, } template -struct TileConfig : OpRewritePattern { +struct IntelAMXTileConfig : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InvokeOpTy op, @@ -58,22 +57,20 @@ struct TileConfig : OpRewritePattern { return failure(); auto flags = dyn_cast(op.getOperand(0).getDefiningOp()).getFlags(); - for (auto flagItr : flags) { + for (auto flagItr : flags) if (flagItr == xsmm::GemmFlagsAttr::get( rewriter.getContext(), mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG) || flagItr == xsmm::GemmFlagsAttr::get( rewriter.getContext(), - mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) { + mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) return failure(); - } - } SmallVector attributesSetup; attributesSetup.push_back(xsmm::GemmFlagsAttr::get( rewriter.getContext(), xsmm::GemmFlags::NO_RESET_TILECONFIG)); appendBrgemmFlags(attributesSetup, rewriter, op); - auto tileConfigSetup = rewriter.create( + auto tileConfigSetup = rewriter.create( op.getLoc(), rewriter.getI64Type(), DenseI64ArrayAttr::get( rewriter.getContext(), @@ -86,7 +83,7 @@ struct TileConfig : OpRewritePattern { attributesReset.push_back(xsmm::GemmFlagsAttr::get( rewriter.getContext(), xsmm::GemmFlags::NO_SETUP_TILECONFIG)); appendBrgemmFlags(attributesReset, rewriter, op); - auto tileConfigReset = rewriter.create( + auto tileConfigReset = rewriter.create( op.getLoc(), rewriter.getI64Type(), DenseI64ArrayAttr::get( rewriter.getContext(), @@ -110,8 +107,8 @@ struct TileConfig : OpRewritePattern { op.getLoc(), MemRefType::get({64}, rewriter.getI8Type())); ValueRange tileConfigInputs{alloca}; - rewriter.create(op.getLoc(), tileConfigSetup, - tileConfigInputs); + rewriter.create( + op.getLoc(), tileConfigSetup, tileConfigInputs); SmallVector invokeOperands; invokeOperands.push_back(dispatch); @@ -124,22 +121,23 @@ struct TileConfig : OpRewritePattern { invokeOperands); ValueRange tileResetInputs{alloca}; - rewriter.create(op.getLoc(), tileConfigReset, - tileResetInputs); + rewriter.create( + op.getLoc(), tileConfigReset, tileResetInputs); - // rewriter.create(op.getLoc(), alloca); rewriter.eraseOp(op); rewriter.eraseOp(op.getOperand(0).getDefiningOp()); return success(); } }; -struct TileConfigInsertionPass - : public impl::TileConfigInsertionPassBase { +struct IntelAMXTileConfigInsertionPass + : public impl::IntelAMXTileConfigInsertionPassBase< + IntelAMXTileConfigInsertionPass> { void populateCombinePatterns(RewritePatternSet &patterns) { - patterns.add>( + patterns.add>( patterns.getContext()); - patterns.add>( + patterns.add< + IntelAMXTileConfig>( patterns.getContext()); } diff --git a/lib/TPP/Transforms/TileConfigHoisting.cpp b/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp similarity index 75% rename from lib/TPP/Transforms/TileConfigHoisting.cpp rename to 
lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp index 1bf1deb50..81d0bd9b3 100644 --- a/lib/TPP/Transforms/TileConfigHoisting.cpp +++ b/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp @@ -1,4 +1,4 @@ -//===- TileConfigHoisting.cpp ---------------------------------------------===// +//===- IntelAMXTileConfigHoisting.cpp -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,16 +10,14 @@ // //===----------------------------------------------------------------------===// #include "TPP/Dialect/Xsmm/XsmmOps.h" -#include "TPP/Dialect/Xsmm/XsmmUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace tpp { -#define GEN_PASS_DEF_TILECONFIGHOISTINGPASS +#define GEN_PASS_DEF_INTELAMXTILECONFIGHOISTINGPASS #include "TPP/Passes.h.inc" } // namespace tpp } // namespace mlir @@ -30,31 +28,32 @@ using namespace mlir::xsmm; namespace mlir { namespace tpp { -struct TileConfigHoisting : OpRewritePattern { +struct IntelAMXTileConfigHoisting : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::AllocaOp alloca, PatternRewriter &rewriter) const override { - xsmm::TileConfigOp firstTileConfig, secondTileConfig; + xsmm::IntelAMXTileConfigOp firstTileConfig, secondTileConfig; for (auto *user : alloca->getUsers()) { - if (!dyn_cast(user)) { + if (!dyn_cast(user)) { return failure(); } - auto flags = - dyn_cast( - dyn_cast(user).getOperand(0).getDefiningOp()) - .getFlags(); + auto flags = dyn_cast( + dyn_cast(user) + .getOperand(0) + .getDefiningOp()) + .getFlags(); for (auto flagItr : flags) { if (flagItr == xsmm::GemmFlagsAttr::get( rewriter.getContext(), mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG)) { - firstTileConfig = dyn_cast(user); + firstTileConfig = dyn_cast(user); } else if (flagItr == xsmm::GemmFlagsAttr::get( rewriter.getContext(), mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG)) { - secondTileConfig = dyn_cast(user); + secondTileConfig = dyn_cast(user); } } } @@ -85,10 +84,11 @@ struct TileConfigHoisting : OpRewritePattern { } }; -struct TileConfigHoistingPass - : public impl::TileConfigHoistingPassBase { +struct IntelAMXTileConfigHoistingPass + : public impl::IntelAMXTileConfigHoistingPassBase< + IntelAMXTileConfigHoistingPass> { void populateCombinePatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } void runOnOperation() override { diff --git a/runtime/Xsmm/XsmmRunnerUtils.cpp b/runtime/Xsmm/XsmmRunnerUtils.cpp index 271c45396..62b784f1f 100644 --- a/runtime/Xsmm/XsmmRunnerUtils.cpp +++ b/runtime/Xsmm/XsmmRunnerUtils.cpp @@ -210,12 +210,10 @@ xsmm_binary_dispatch(const libxsmm_meltw_binary_type op_type, return reinterpret_cast(kernel); } -extern "C" int64_t xsmm_tile_config_dispatch(const libxsmm_datatype dtype, - int64_t m, int64_t n, int64_t k, - int64_t lda, int64_t ldb, - int64_t ldc, int64_t stride_a, - int64_t stride_b, - const libxsmm_gemm_flags flags) { +extern "C" int64_t xsmm_intel_amx_tile_config_dispatch( + const libxsmm_datatype dtype, int64_t m, int64_t n, int64_t k, int64_t lda, + int64_t ldb, int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags flags) { libxsmm_blasint m_int = m; 
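+  // Note: the l_shape fields below swap m/n and lda/ldb; presumably this maps
+  // the row-major call onto libxsmm's column-major GEMM convention.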
libxsmm_blasint n_int = n; libxsmm_blasint k_int = k; @@ -459,8 +457,8 @@ xsmm_fused_brgemm_dispatch(const libxsmm_datatype data_type, int64_t m, } extern "C" MLIR_RUNNERUTILS_EXPORT void -xsmm_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, - void *tileState, int64_t offset) { +xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *tileState, int64_t offset) { libxsmm_xmmfunction cfg_tr; libxsmm_tilecfg_state *l_tilestate = diff --git a/runtime/Xsmm/XsmmRunnerUtils.h b/runtime/Xsmm/XsmmRunnerUtils.h index 6ca412553..3d5048188 100644 --- a/runtime/Xsmm/XsmmRunnerUtils.h +++ b/runtime/Xsmm/XsmmRunnerUtils.h @@ -44,7 +44,7 @@ extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_fused_brgemm_dispatch( const libxsmm_meltw_binary_flags binary_flags, const libxsmm_meltw_binary_type binary_op_type); -extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_tile_config_dispatch( +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_intel_amx_tile_config_dispatch( const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, const libxsmm_gemm_flags); @@ -79,7 +79,7 @@ extern "C" MLIR_RUNNERUTILS_EXPORT void xsmm_fused_brgemm_invoke( int64_t offsetC, void *alignedPtrD, int64_t offsetD, int64_t numBatches); extern "C" MLIR_RUNNERUTILS_EXPORT void -xsmm_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, - void *alignedPtrA, int64_t offset); +xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offset); #endif // TPP_EXECUTIONENGINE_CRUNNERUTILS_H From eb7360bc1e215ecc0f32187cdc2ec2c7df1678f8 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Wed, 28 Feb 2024 01:16:14 -0800 Subject: [PATCH 04/10] Build fix --- lib/TPP/DefaultPipeline.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp index 4aff3fc62..5d8527d8c 100644 --- a/lib/TPP/DefaultPipeline.cpp +++ b/lib/TPP/DefaultPipeline.cpp @@ -144,7 +144,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, mlir::tpp::SCFParallelLoopTilingOptions tilingOptions; tilingOptions.tileSizes = parallelTaskGrid; pm.addPass(createSCFParallelLoopTiling(tilingOptions)); - pm.addNestedPass(createTileConfigInsertionPass()); + pm.addNestedPass(createIntelAMXTileConfigInsertionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createIntelAMXTileConfigHoistingPass()); pm.addPass(createConvertSCFToOpenMPPass()); } pm.addPass(createConvertVectorToSCFPass()); From a6134d64fa545e216487ed21ef13b2149035b2b2 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Thu, 29 Feb 2024 20:08:11 -0800 Subject: [PATCH 05/10] Fixed pass pipeline and some minor edits --- include/TPP/Passes.td | 13 +++++++++++-- lib/TPP/DefaultPipeline.cpp | 16 +++++----------- lib/TPP/DefaultTppPasses.cpp | 17 ++++++++++++++++- lib/TPP/Transforms/IntelAMXTileConfig.cpp | 2 +- .../Transforms/IntelAMXTileConfigHoisting.cpp | 2 +- lib/TPP/Transforms/SCFParallelLoopTiling.cpp | 4 ++-- 6 files changed, 36 insertions(+), 18 deletions(-) diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index f5fb77131..8df063e6b 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -204,7 +204,10 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> { let options= [ Option<"linalgToLoops", "linalg-to-loops", "bool", /*default=*/"false", - "Skip all TPP transformations. Lower linalg directly to loops."> + "Skip all TPP transformations. 
Lower linalg directly to loops.">, + ListOption<"parallelTaskGrid", "parallel-task-grid", + "unsigned", "Grid-sizes for parallel tasks."> + ]; } @@ -289,6 +292,12 @@ def LocalDialectsLowering : Pass<"lower-local-dialects", "ModuleOp"> { "tensor::TensorDialect", "xsmm::XsmmDialect", "LLVM::LLVMDialect"]; + let options = [ + ListOption<"parallelTaskGrid", "parallel-task-grid", + "unsigned", "Grid-sizes for parallel tasks."> + + ]; + } def Postprocessing : Pass<"postprocess", "func::FuncOp"> { @@ -474,7 +483,7 @@ def FoldXsmmFlags : Pass<"fold-xsmm-flags", "func::FuncOp"> { def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> { let summary = "Tile parallel loops"; let options = [ - ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t", + ListOption<"tileSizes", "parallel-loop-tile-sizes", "unsigned", "Factors to tile parallel loops by">, Option<"noMinMaxBounds", "no-min-max-bounds", "bool", /*default=*/"false", diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp index 5d8527d8c..b5c20285c 100644 --- a/lib/TPP/DefaultPipeline.cpp +++ b/lib/TPP/DefaultPipeline.cpp @@ -47,10 +47,10 @@ llvm::cl::opt llvm::cl::init(false)); // Control grid parallelism sizes. -llvm::cl::list +llvm::cl::list parallelTaskGrid("parallel-task-grid", llvm::cl::desc("Grid-sizes for parallel tasks"), - llvm::cl::list_init(SmallVector{2, 8}), + llvm::cl::list_init(SmallVector{2, 8}), llvm::cl::CommaSeparated); namespace mlir { @@ -127,7 +127,8 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend})); } else { // Apply the default preprocessing pass - DefaultTppPassesOptions tppDefaultOptions{linalgToLoops}; + DefaultTppPassesOptions tppDefaultOptions{linalgToLoops, + parallelTaskGrid}; pm.addPass(createDefaultTppPasses(tppDefaultOptions)); } @@ -140,15 +141,8 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, pm.addPass(tpp::createConvertPerfToFunc()); pm.addPass(createConvertTensorToLinalgPass()); pm.addNestedPass(createConvertLinalgToLoopsPass()); - if (defParallel) { - mlir::tpp::SCFParallelLoopTilingOptions tilingOptions; - tilingOptions.tileSizes = parallelTaskGrid; - pm.addPass(createSCFParallelLoopTiling(tilingOptions)); - pm.addNestedPass(createIntelAMXTileConfigInsertionPass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass(createIntelAMXTileConfigHoistingPass()); + if (defParallel) pm.addPass(createConvertSCFToOpenMPPass()); - } pm.addPass(createConvertVectorToSCFPass()); pm.addPass(arith::createArithExpandOpsPass()); pm.addPass(createLowerAffinePass()); diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp index b7905cf24..fec85685d 100644 --- a/lib/TPP/DefaultTppPasses.cpp +++ b/lib/TPP/DefaultTppPasses.cpp @@ -77,6 +77,10 @@ struct LocalDialectsLowering : public tpp::impl::LocalDialectsLoweringBase, UtilityPassBase { + LocalDialectsLowering() {} + LocalDialectsLowering(const LocalDialectsLoweringOptions &options) { + parallelTaskGrid = options.parallelTaskGrid; + } void runOnOperation() override { auto module = getOperation(); @@ -106,6 +110,16 @@ struct LocalDialectsLowering // that they are hoisted out of loops. 
pm.addNestedPass(createCleanup()); + mlir::tpp::SCFParallelLoopTilingOptions tilingOptions; + tilingOptions.tileSizes = parallelTaskGrid; + pm.addPass(createSCFParallelLoopTiling(tilingOptions)); + + pm.addNestedPass(createIntelAMXTileConfigInsertionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createLoopInvariantCodeMotionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createIntelAMXTileConfigHoistingPass()); + pm.addPass(createConvertXsmmToFunc()); pm.addPass(createConvertPerfToFunc()); } @@ -310,7 +324,8 @@ struct DefaultTppPasses pm.addPass(createConvertForAllToParallelOp()); // Covert all local TPP-related dialects. - pm.addPass(createLocalDialectsLowering()); + LocalDialectsLoweringOptions localDialectsLowering{parallelTaskGrid}; + pm.addPass(createLocalDialectsLowering(localDialectsLowering)); // Clean up after the default pipeline. pm.addNestedPass(createPostprocessing()); diff --git a/lib/TPP/Transforms/IntelAMXTileConfig.cpp b/lib/TPP/Transforms/IntelAMXTileConfig.cpp index 3a3c19bb3..be75be93e 100644 --- a/lib/TPP/Transforms/IntelAMXTileConfig.cpp +++ b/lib/TPP/Transforms/IntelAMXTileConfig.cpp @@ -1,4 +1,4 @@ -//===- IntelAMXTileConfig.cpp ---------------------------------------------===// +//===- IntelAMXTileConfig.cpp ------------------------------------*- C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp b/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp index 81d0bd9b3..9c222a997 100644 --- a/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp +++ b/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp @@ -1,4 +1,4 @@ -//===- IntelAMXTileConfigHoisting.cpp -------------------------------------===// +//===- IntelAMXTileConfigHoisting.cpp ----------------------------*- C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/TPP/Transforms/SCFParallelLoopTiling.cpp b/lib/TPP/Transforms/SCFParallelLoopTiling.cpp index f6632ed1a..8379e1293 100644 --- a/lib/TPP/Transforms/SCFParallelLoopTiling.cpp +++ b/lib/TPP/Transforms/SCFParallelLoopTiling.cpp @@ -54,7 +54,7 @@ using namespace mlir::scf; /// %i0 + j0 and %i1 + %j1. /// /// The old loop is replaced with the new one. 
-void tileParallelLoop(ParallelOp op, ArrayRef tileSizes, +void tileParallelLoop(ParallelOp op, ArrayRef tileSizes, bool noMinMaxBounds) { bool useParallelOp = false; /* TODO, need to implement this case */ @@ -244,7 +244,7 @@ namespace { struct SCFParallelLoopTiling : public tpp::impl::SCFParallelLoopTilingBase { SCFParallelLoopTiling(){}; - SCFParallelLoopTiling(ArrayRef tileSizes, + SCFParallelLoopTiling(ArrayRef tileSizes, bool noMinMaxBounds = false) { this->tileSizes = tileSizes; this->noMinMaxBounds = noMinMaxBounds; From 97a8beed4ea67f2f00d0e2ae1ff91156055786f3 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Thu, 29 Feb 2024 21:42:19 -0800 Subject: [PATCH 06/10] Minor fix --- lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp b/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp index 9c222a997..df0ba42c1 100644 --- a/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp +++ b/lib/TPP/Transforms/IntelAMXTileConfigHoisting.cpp @@ -60,7 +60,7 @@ struct IntelAMXTileConfigHoisting : OpRewritePattern { scf::ParallelOp parallelOpParent = NULL; auto op = alloca.getOperation(); - while (true) { + while (op) { if (op->getParentOfType()) { if (&op->getParentOfType().getRegion() == alloca->getParentRegion()) { From 85f2ef430807a1701d1205584f17f80a5a5960e7 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Thu, 29 Feb 2024 22:45:26 -0800 Subject: [PATCH 07/10] Fix for benchmarks from Alex --- lib/TPP/Transforms/IntelAMXTileConfig.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/TPP/Transforms/IntelAMXTileConfig.cpp b/lib/TPP/Transforms/IntelAMXTileConfig.cpp index be75be93e..0f90a1446 100644 --- a/lib/TPP/Transforms/IntelAMXTileConfig.cpp +++ b/lib/TPP/Transforms/IntelAMXTileConfig.cpp @@ -125,7 +125,8 @@ struct IntelAMXTileConfig : OpRewritePattern { op.getLoc(), tileConfigReset, tileResetInputs); rewriter.eraseOp(op); - rewriter.eraseOp(op.getOperand(0).getDefiningOp()); + if (op.getOperand(0).getDefiningOp()->getUsers().empty()) + rewriter.eraseOp(op.getOperand(0).getDefiningOp()); return success(); } }; From a23402b426ad34859a699d8a227ebbc53c906fe0 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Thu, 29 Feb 2024 22:52:43 -0800 Subject: [PATCH 08/10] Test fixes for LICM hoisting addis --- test/Passes/pass-convert-gemm-to-parallel-tile.mlir | 2 +- test/Passes/pass-convert-mlp-to-parallel-tile.mlir | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Passes/pass-convert-gemm-to-parallel-tile.mlir b/test/Passes/pass-convert-gemm-to-parallel-tile.mlir index 8efc7ab18..e6f75a897 100644 --- a/test/Passes/pass-convert-gemm-to-parallel-tile.mlir +++ b/test/Passes/pass-convert-gemm-to-parallel-tile.mlir @@ -31,7 +31,7 @@ module { // CHECK: omp.wsloop for (%[[ARG3:.*]], %[[ARG4:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) { // CHECK: memref.alloca_scope { // CHECK: scf.for %[[ARG5:.*]] = %[[c0]] to %[[c2]] step %[[c1]] { +// CHECK: %[[temp1:.*]] = arith.addi %[[ARG5]], %[[ARG3]] : index // CHECK: scf.for %[[ARG6:.*]] = %[[c0]] to %[[c8]] step %[[c1]] { -// CHECK: %[[temp1:.*]] = arith.addi %[[ARG5]], %[[ARG3]] : index // CHECK: %[[temp2:.*]] = arith.addi %[[ARG6]], %[[ARG4]] : index diff --git a/test/Passes/pass-convert-mlp-to-parallel-tile.mlir b/test/Passes/pass-convert-mlp-to-parallel-tile.mlir index c941c15cd..2227dc44f 100644 --- a/test/Passes/pass-convert-mlp-to-parallel-tile.mlir +++ 
b/test/Passes/pass-convert-mlp-to-parallel-tile.mlir
@@ -82,22 +82,22 @@ module {
 //CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
 //CHECK: memref.alloca_scope {
 //CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
-//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
 //CHECK: omp.parallel {
 //CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
 //CHECK: memref.alloca_scope {
 //CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
-//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
 //CHECK: omp.parallel {
 //CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
 //CHECK: memref.alloca_scope {
 //CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
-//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
 //CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index

From 42578b4e829c83eb2af03749442ef31bdea32259 Mon Sep 17 00:00:00 2001
From: Kavitha Madhu
Date: Thu, 29 Feb 2024 23:27:46 -0800
Subject: [PATCH 09/10] Tile config insertion pass test case

---
 test/Passes/pass-tileconfig-insertion.mlir | 113 +++++++++++++++++++++
 1 file changed, 113 insertions(+)
 create mode 100644 test/Passes/pass-tileconfig-insertion.mlir

diff --git a/test/Passes/pass-tileconfig-insertion.mlir b/test/Passes/pass-tileconfig-insertion.mlir
new file mode 100644
index 000000000..5ccfdd87d
--- /dev/null
+++ b/test/Passes/pass-tileconfig-insertion.mlir
@@ -0,0 +1,113 @@
+// RUN: tpp-opt %s --intel-amx-tile-config-insertion-pass | FileCheck %s
+
+module {
+  memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @entry(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %c8 = arith.constant 8 : index
+    %c0 = arith.constant 0 : index
+    %c32_i64 = arith.constant 32 : i64
+    %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+    %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
+    %c0_0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %c8_1 = arith.constant 8 : index
+    %2 = arith.muli %c1, %c2 : index
+    %3 = arith.muli %c1, %c8_1 : index
+    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%2, %3) {
+      scf.for %arg3 = %c0_0 to %2 step %c1 {
+        scf.for %arg4 = %c0_0 to %3 step %c1 {
+          %8 = arith.addi %arg3, %arg1 : index
+          %9 = arith.addi %arg4, %arg2 : index
+          %subview = memref.subview %alloc[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+          %subview_9 = memref.subview %arg0[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+          xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        }
+      }
+      scf.reduce
+    }
+    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+    %c0_3 = arith.constant 0 : index
+    %c2_4 = arith.constant 2 : index
+    %c8_5 = arith.constant 8 : index
+    %4 = arith.muli %c1, %c2_4 : index
+    %5 = arith.muli %c1, %c8_5 : index
+    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%4, %5) {
+      scf.for %arg3 = %c0_3 to %4 step %c1 {
+        scf.for %arg4 = %c0_3 to %5 step %c1 {
+          %8 = arith.addi %arg3, %arg1 : index
+          %9 = arith.addi %arg4, %arg2 : index
+          %subview = memref.subview %alloc_2[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+          %subview_9 = memref.subview %alloc[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+          xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        }
+      }
+      scf.reduce
+    }
+    %c0_6 = arith.constant 0 : index
+    %c2_7 = arith.constant 2 : index
+    %c8_8 = arith.constant 8 : index
+    %6 = arith.muli %c1, %c2_7 : index
+    %7 = arith.muli %c1, %c8_8 : index
+    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%6, %7) {
+      scf.for %arg3 = %c0_6 to %6 step %c1 {
+        scf.for %arg4 = %c0_6 to %7 step %c1 {
+          %8 = arith.addi %arg3, %arg1 : index
+          %9 = arith.addi %arg4, %arg2 : index
+          %subview = memref.subview %alloc[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+          %subview_9 = memref.subview %alloc_2[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+          xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        }
+      }
+      scf.reduce
+    }
+    memref.dealloc %alloc_2 : memref<8x32x32x32xbf16>
+    return %alloc : memref<8x32x32x32xbf16>
+  }
+}
+
+// CHECK:func.func @entry(%[[ARG0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
+// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
+// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
+// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()

From b9166b534a48b0af6150e00a784975dce6120cf4 Mon Sep 17 00:00:00 2001
From: Kavitha Madhu
Date: Fri, 1 Mar 2024 00:07:38 -0800
Subject: [PATCH 10/10] Tile config hoisting pass test

---
 .../Passes/pass-tileconfig-hoisting-pass.mlir | 132 ++++++++++++++++++
 1 file changed, 132 insertions(+)
 create mode 100644 test/Passes/pass-tileconfig-hoisting-pass.mlir

diff --git a/test/Passes/pass-tileconfig-hoisting-pass.mlir b/test/Passes/pass-tileconfig-hoisting-pass.mlir
new file mode 100644
index 000000000..5e4786476
--- /dev/null
+++ b/test/Passes/pass-tileconfig-hoisting-pass.mlir
@@ -0,0 +1,132 @@
+// RUN: tpp-opt %s --intel-amx-tile-config-hoisting-pass | FileCheck %s
+
+module{
+
+memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
+
+func.func @entry(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c32_i64 = arith.constant 32 : i64
+  %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+  %1 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+  %2 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  %3 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) {
+    scf.for %arg3 = %c0 to %c2 step %c1 {
+      %10 = arith.addi %arg3, %arg1 : index
+      %subview = memref.subview %arg0[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg4 = %c0 to %c8 step %c1 {
+        %11 = arith.addi %arg4, %arg2 : index
+        %subview_1 = memref.subview %alloc[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+        %alloca = memref.alloca() : memref<64xi8>
+        "xsmm.IntelAMXtileConfig"(%1, %alloca) : (i64, memref<64xi8>) -> ()
+        xsmm.brgemm(data_type = bf16, %3, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        "xsmm.IntelAMXtileConfig"(%2, %alloca) : (i64, memref<64xi8>) -> ()
+      }
+    }
+    scf.reduce
+  }
+  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+  %4 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+  %5 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  %6 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) {
+    scf.for %arg3 = %c0 to %c2 step %c1 {
+      %10 = arith.addi %arg3, %arg1 : index
+      %subview = memref.subview %alloc[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg4 = %c0 to %c8 step %c1 {
+        %11 = arith.addi %arg4, %arg2 : index
+        %subview_1 = memref.subview %alloc_0[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+        %alloca = memref.alloca() : memref<64xi8>
+        "xsmm.IntelAMXtileConfig"(%4, %alloca) : (i64, memref<64xi8>) -> ()
+        xsmm.brgemm(data_type = bf16, %6, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        "xsmm.IntelAMXtileConfig"(%5, %alloca) : (i64, memref<64xi8>) -> ()
+      }
+    }
+    scf.reduce
+  }
+  %7 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+  %8 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  %9 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) {
+    scf.for %arg3 = %c0 to %c2 step %c1 {
+      %10 = arith.addi %arg3, %arg1 : index
+      %subview = memref.subview %alloc_0[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg4 = %c0 to %c8 step %c1 {
+        %11 = arith.addi %arg4, %arg2 : index
+        %subview_1 = memref.subview %alloc[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+        %alloca = memref.alloca() : memref<64xi8>
+        "xsmm.IntelAMXtileConfig"(%7, %alloca) : (i64, memref<64xi8>) -> ()
+        xsmm.brgemm(data_type = bf16, %9, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+        "xsmm.IntelAMXtileConfig"(%8, %alloca) : (i64, memref<64xi8>) -> ()
+      }
+    }
+    scf.reduce
+  }
+  memref.dealloc %alloc_0 : memref<8x32x32x32xbf16>
+  return %alloc : memref<8x32x32x32xbf16>
+}
+}
+
+// CHECK-LABEL: func.func @entry(
+// CHECK: %[[ARG0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64
+// CHECK: %[[dispatch1:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[dispatch2:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[brgemmdispatch:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch1]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 {
+// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: xsmm.brgemm(data_type = bf16, %[[brgemmdispatch]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: }
+// CHECK: }
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch2]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: %[[dispatch3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[dispatch4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[brgemmdispatch2:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch3]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 {
+// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: xsmm.brgemm(data_type = bf16, %[[brgemmdispatch2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: }
+// CHECK: }
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: %[[dispatch5:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[dispatch6:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: %[[brgemmdispatch3:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16
+// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
+// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8>
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch5]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
+// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index
+// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 {
+// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index
+// CHECK: xsmm.brgemm(data_type = bf16, %[[brgemmdispatch3]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]])
+// CHECK: }
+// CHECK: }
+// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch6]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
+// CHECK: scf.reduce
+// CHECK: }