From 5d8362ca8acb7896338de9768f57f597230d8a06 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Tue, 6 Aug 2024 15:09:35 -0400
Subject: [PATCH] [Codegen][GPU] Add kernel config for LLVMGPUTileAndFuse for targeting mma (#18105)

This patch adds kernel configuration logic for contraction-like operations
to use mma instructions with the LLVMGPUTileAndFuse pipeline. It directly
leverages the configuration logic already present and in use for the
existing matmul-based pipelines, but generates a config for TileAndFuse
instead.
---
 .../Dialect/Codegen/IR/IREECodegenAttrs.h     |  26 +-
 .../Codegen/Dialect/GPU/IR/BUILD.bazel        |   4 +-
 .../Codegen/Dialect/GPU/IR/CMakeLists.txt     |   4 +-
 .../Dialect/GPU/IR/DerivedConfigUtils.cpp     | 111 ++++++++
 .../Dialect/GPU/IR/DerivedConfigUtils.h       |  18 ++
 .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp   |   3 +-
 .../Dialect/GPU/TargetUtils/BUILD.bazel       |   3 +-
 .../Dialect/GPU/TargetUtils/CMakeLists.txt    |   3 +-
 .../Dialect/GPU/TargetUtils/ConfigUtils.cpp   | 257 ++++++++++++------
 .../Dialect/GPU/TargetUtils/ConfigUtils.h     |   9 +-
 .../Codegen/LLVMCPU/KernelDispatch.cpp        |   2 +-
 .../iree/compiler/Codegen/LLVMGPU/BUILD.bazel |   1 +
 .../compiler/Codegen/LLVMGPU/CMakeLists.txt   |   1 +
 .../compiler/Codegen/LLVMGPU/KernelConfig.cpp |  38 ++-
 .../Codegen/LLVMGPU/ROCDLKernelConfig.cpp     |  22 +-
 .../Codegen/LLVMGPU/test/ROCDL/BUILD.bazel    |   1 +
 .../Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt |   1 +
 .../test/ROCDL/config_tile_and_fuse.mlir      |  58 ++++
 .../compiler/Codegen/SPIRV/KernelConfig.cpp   |   4 +-
 19 files changed, 463 insertions(+), 103 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp
 create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.h
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
index a26b6b5e26d0..71d54e83496f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
@@ -111,6 +111,24 @@ SmallVector<Value> getTileSizes(OpBuilder &b, Operation *op, unsigned level);
 /// Sets the lowering configuration, overwriting existing attribute values.
 void setLoweringConfig(Operation *op, Attribute config);
 
+/// Convenience function that sets the lowering configuration on the operation
+/// and translation info for a generic lowering config, lowering pipeline,
+/// and optional workgroup/subgroup size.
+inline LogicalResult setOpConfigAndEntryPointFnTranslation(
+    mlir::FunctionOpInterface entryPointFn, Operation *op,
+    IREE::Codegen::LoweringConfigAttrInterface config,
+    IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
+    ArrayRef<int64_t> workgroupSize = {},
+    std::optional<int64_t> subgroupSize = {},
+    DictionaryAttr pipelineConfig = DictionaryAttr()) {
+  MLIRContext *context = entryPointFn.getContext();
+  setLoweringConfig(op, config);
+  auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
+      context, passPipeline, SymbolRefAttr(), workgroupSize, subgroupSize,
+      pipelineConfig);
+  return setTranslationInfo(entryPointFn, translationInfo);
+}
+
 /// Convenience function that sets the lowering configuration on the operation
 /// and translation info on the entry point op for the common case of specifying
 /// tile sizes to use for the operation, and pass pipeline to use for the
@@ -126,11 +144,9 @@ inline LogicalResult setOpConfigAndEntryPointFnTranslation(
   MLIRContext *context = entryPointFn.getContext();
   auto config = IREE::Codegen::LoweringConfigAttr::get(context, tileSizes,
                                                        scalableTileFlags);
-  setLoweringConfig(op, config);
-  auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
-      entryPointFn.getContext(), passPipeline, SymbolRefAttr(), workgroupSize,
-      subgroupSize, pipelineConfig);
-  return setTranslationInfo(entryPointFn, translationInfo);
+  return setOpConfigAndEntryPointFnTranslation(entryPointFn, op, config,
+                                               passPipeline, workgroupSize,
+                                               subgroupSize, pipelineConfig);
 }
 
 /// Overload of setOpConfigAndEntryPointFnTranslation() for the "no scalable
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel
index 9e12352c989a..b71a512e7927 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel
@@ -46,12 +46,14 @@ iree_td_library(
 iree_compiler_cc_library(
     name = "IREEGPUDialect",
     srcs = [
+        "DerivedConfigUtils.cpp",
         "IREEGPUAttrs.cpp",
         "IREEGPUDialect.cpp",
         "IREEGPUInterfaces.cpp",
         "IREEGPUOps.cpp",
     ],
     hdrs = [
+        "DerivedConfigUtils.h",
         "IREEGPUAttrs.h",
         "IREEGPUDialect.h",
         "IREEGPUEnums.h",
@@ -77,9 +79,9 @@ iree_compiler_cc_library(
         ":IREEGPUInterfaces",
         ":IREEGPUOpsGen",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
-        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:ConfigUtils",
         "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
         "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
+        "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AMDGPUDialect",
         "@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt
index b66ab3c7b089..bebaf9cc05e2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ iree_cc_library(
   NAME
     IREEGPUDialect
   HDRS
+    "DerivedConfigUtils.h"
     "IREEGPUAttrs.h"
     "IREEGPUDialect.h"
     "IREEGPUEnums.h"
@@ -31,6 +32,7 @@ iree_cc_library(
     "IREEGPUOps.cpp.inc"
     "IREEGPUOps.h.inc"
   SRCS
+    "DerivedConfigUtils.cpp"
    "IREEGPUAttrs.cpp"
    "IREEGPUDialect.cpp"
    "IREEGPUInterfaces.cpp"
@@ -56,9 +58,9 @@ iree_cc_library(
    MLIRVectorDialect
    MLIRVectorInterfaces
    iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
-   iree::compiler::Codegen::Dialect::GPU::TargetUtils::ConfigUtils
    iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
    iree::compiler::Codegen::Utils::VectorOpUtils
+   iree::compiler::Dialect::LinalgExt::IR
   PUBLIC
 )
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp
new file mode 100644
index 000000000000..64af9ddb42c5
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp
@@ -0,0 +1,111 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.h"
+#include <numeric>
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+static constexpr int64_t kPreferredCopyNumBits = 128;
+
+SmallVector<int64_t>
+getThreadTileSizesFromLoopRanges(SmallVector<int64_t> loopRanges,
+                                 int64_t numThreads, int64_t vectorSize) {
+  // TODO: We shouldn't need this check, however loop fusion currently requires
+  // loop trip counts to be identical, meaning we need to use a num_threads
+  // variant of tiling. Remove this and simply return the preferred vector size
+  // once loop fusion can resolve the forall properly.
+  if (llvm::any_of(loopRanges,
+                   [](int64_t s) { return ShapedType::isDynamic(s); })) {
+    return {};
+  }
+
+  int64_t flatNumTrips = std::accumulate(loopRanges.begin(), loopRanges.end(),
+                                         1, std::multiplies<int64_t>());
+  if (flatNumTrips % numThreads != 0) {
+    return {};
+  }
+  int64_t maxVectorSize = flatNumTrips / numThreads;
+
+  while (maxVectorSize % vectorSize != 0) {
+    vectorSize /= 2;
+  }
+
+  SmallVector<int64_t> tileSizes(loopRanges.size(), 0);
+  tileSizes.back() = vectorSize;
+  int64_t residualNumThreads = numThreads / (loopRanges.back() / vectorSize);
+  for (int i = tileSizes.size() - 2, e = 0; i >= e; --i) {
+    if (loopRanges[i] >= residualNumThreads) {
+      tileSizes[i] = loopRanges[i] / residualNumThreads;
+      residualNumThreads = 1;
+      break;
+    }
+    tileSizes[i] = 1;
+    residualNumThreads /= loopRanges[i];
+  }
+
+  return tileSizes;
+}
+
+SmallVector<int64_t> deriveLinalgOpThreadTileSizes(linalg::LinalgOp linalgOp,
+                                                   int64_t numThreads) {
+  if (!linalgOp.hasPureTensorSemantics()) {
+    return {};
+  }
+  // TODO: Support multi-result
+  if (linalgOp->getNumResults() != 1) {
+    return {};
+  }
+  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
+  int64_t vectorSize = kPreferredCopyNumBits /
+                       getElementTypeOrSelf(linalgOp->getResultTypes()[0])
+                           .getIntOrFloatBitWidth();
+  return getThreadTileSizesFromLoopRanges(loopRanges, numThreads, vectorSize);
+}
+
+SmallVector<int64_t>
+deriveIm2colOpThreadTileSizes(IREE::LinalgExt::Im2colOp im2colOp,
+                              int64_t numThreads) {
+  if (!im2colOp.hasPureTensorSemantics()) {
+    return {};
+  }
+  // TODO(Max191): Add `getStaticLoopRanges` to TilingInterface, and use it
+  // here instead of `im2colOp.getOutputType().getShape()`. Then we can also
+  // get rid of the specialization for Im2colOp vs LinalgOp and just use
+  // TilingInterface ops.
+  SmallVector<int64_t> loopRanges(im2colOp.getOutputType().getShape());
+  int64_t vectorSize = kPreferredCopyNumBits /
+                       getElementTypeOrSelf(im2colOp->getResultTypes()[0])
+                           .getIntOrFloatBitWidth();
+  return getThreadTileSizesFromLoopRanges(loopRanges, numThreads, vectorSize);
+}
+
+SmallVector<int64_t> deriveThreadTileSizes(Operation *op) {
+  std::optional<SmallVector<int64_t>> workgroupSize =
+      getWorkgroupSize(op->getParentOfType<FunctionOpInterface>());
+  if (!workgroupSize) {
+    return {};
+  }
+  int64_t numThreads =
+      std::accumulate(workgroupSize->begin(), workgroupSize->end(), 1,
+                      std::multiplies<int64_t>());
+  return TypeSwitch<Operation *, SmallVector<int64_t>>(op)
+      .Case([&](linalg::LinalgOp linalgOp) -> SmallVector<int64_t> {
+        return deriveLinalgOpThreadTileSizes(linalgOp, numThreads);
+      })
+      .Case([&](IREE::LinalgExt::Im2colOp im2colOp) -> SmallVector<int64_t> {
+        return deriveIm2colOpThreadTileSizes(im2colOp, numThreads);
+      })
+      .Default([](Operation *op) -> SmallVector<int64_t> { return {}; });
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.h
new file mode 100644
index 000000000000..68c8174c3017
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.h
@@ -0,0 +1,18 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_DERIVEDCONFIGUTILS_H_
+#define IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_DERIVEDCONFIGUTILS_H_
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+SmallVector<int64_t> deriveThreadTileSizes(Operation *op);
+
+} // namespace mlir::iree_compiler::IREE::GPU
+
+#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_DERIVEDCONFIGUTILS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index c5f99c0ff031..00ca348a2edf 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -7,10 +7,11 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include <numeric>
 
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
-#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
 #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
 #include "llvm/ADT/STLExtras.h"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel
index 60760c3414c8..45957a4d73c4 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel
@@ -21,8 +21,9 @@ iree_compiler_cc_library(
         "ConfigUtils.h",
     ],
     deps = [
+        "//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
-        "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt index 34abf4cc3cd9..8cf7e05c614f 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt @@ -23,8 +23,9 @@ iree_cc_library( MLIRIR MLIRLinalgDialect MLIRSupport + iree::compiler::Codegen::Common::GPU::GPUHeuristics iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect - iree::compiler::Dialect::LinalgExt::IR + iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 172f9016a10a..656eb39d050d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -5,108 +5,211 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h" -#include +#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" -#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LogicalResult.h" + +#define DEBUG_TYPE "iree-gpu-config-utils" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir::iree_compiler::IREE::GPU { -static constexpr int64_t kPreferredCopyNumBits = 128; - -SmallVector -getThreadTileSizesFromLoopRanges(SmallVector loopRanges, - int64_t numThreads, int64_t vectorSize) { - // TODO: We shouldn't need this check, however loop fusion currently requires - // loop trip counts to be identical, meaning we need to use a num_threads - // variant of tiling. Remove this and simply return the preferred vector size - // once loop fusion can resolve the forall properly. 
-  if (llvm::any_of(loopRanges,
-                   [](int64_t s) { return ShapedType::isDynamic(s); })) {
-    return {};
+LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
+                                      mlir::FunctionOpInterface entryPoint,
+                                      Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp) {
+    return failure();
+  }
+
+  if (target.getWgp().getMma().empty())
+    return failure();
+
+  const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
+
+  SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
+  FailureOr<mlir::linalg::ContractionDimensions> contractionDims =
+      mlir::linalg::inferContractionDims(linalgOp);
+  if (failed(contractionDims)) {
+    return failure();
   }
 
-  int64_t flatNumTrips = std::accumulate(loopRanges.begin(), loopRanges.end(),
-                                         1, std::multiplies<int64_t>());
-  if (flatNumTrips % numThreads != 0) {
-    return {};
+  if (contractionDims->k.empty() || contractionDims->m.empty() ||
+      contractionDims->n.empty()) {
+    return failure();
   }
-  int64_t maxVectorSize = flatNumTrips / numThreads;
-
-  while (maxVectorSize % vectorSize != 0) {
-    vectorSize /= 2;
+
+  // For now we are not being smart and trying to reshape dimensions to allow
+  // for better usage of intrinsics, and instead are tiling all dimensions
+  // except the innermost m, n, and k dimensions to 1.
+  int64_t mDim = contractionDims->m.back();
+  int64_t nDim = contractionDims->n.back();
+  int64_t kDim = contractionDims->k.back();
+
+  // Dynamic dims are expected to be taken care of earlier in the pipeline.
+  if (ShapedType::isDynamic(bounds[mDim]) ||
+      ShapedType::isDynamic(bounds[nDim]) ||
+      ShapedType::isDynamic(bounds[kDim])) {
+    return failure();
   }
 
-  SmallVector<int64_t> tileSizes(loopRanges.size(), 0);
-  tileSizes.back() = vectorSize;
-  int64_t residualNumThreads = numThreads / (loopRanges.back() / vectorSize);
-  for (int i = tileSizes.size() - 2, e = 0; i >= e; --i) {
-    if (loopRanges[i] >= residualNumThreads) {
-      tileSizes[i] = loopRanges[i] / residualNumThreads;
-      residualNumThreads = 1;
-      break;
+  Value lhs = linalgOp.getDpsInputOperand(0)->get();
+  Value rhs = linalgOp.getDpsInputOperand(1)->get();
+  Value init = linalgOp.getDpsInitOperand(0)->get();
+
+  Type lhsElemType = getElementTypeOrSelf(lhs);
+  Type rhsElemType = getElementTypeOrSelf(rhs);
+  Type initElemType = getElementTypeOrSelf(init);
+
+  GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
+                             lhsElemType, rhsElemType, initElemType};
+
+  SmallVector<GPUMatmulShapeType> intrinsics;
+  SmallVector<IREE::GPU::MMAAttr> supportedMmas;
+  for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+    IREE::GPU::MMAIntrinsic type = mma.getIntrinsic().getValue();
+    // TODO: Drop this once all intrinsics are supported.
+    if (type != IREE::GPU::MMAIntrinsic::MFMA_F16_16x16x16_F32 &&
+        type != IREE::GPU::MMAIntrinsic::MFMA_I8_16x16x32_I32) {
+      continue;
     }
-    tileSizes[i] = 1;
-    residualNumThreads /= loopRanges[i];
+    supportedMmas.push_back(mma);
+
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
+    if (mma.getSubgroupSize() != targetSubgroupSize)
+      continue;
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
   }
 
-  return tileSizes;
-}
+  if (intrinsics.empty())
+    return failure();
+
+  GPUMMAHeuristicSeeds seeds;
+
+  // Note that the following heuristic seeds are just placeholder values.
+  // We need to clean this up and make it adjust to different targets.
+  // See https://github.com/iree-org/iree/issues/16341 for details.
+  if (problem.mSize * problem.nSize <= 512 * 512) {
+    // For matmuls with small M*N size, we want to distribute M*N onto more
+    // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
+    // and a larger bestKTileCountPerSubgroup.
+    seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
+             /*bestMNTileCountPerSubgroup=*/4,
+             /*bestKTileCountPerSubgroup=*/8};
+  } else {
+    seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
+             /*bestMNTileCountPerSubgroup=*/8,
+             /*bestKTileCountPerSubgroup=*/4};
+  }
+
+  int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
 
-SmallVector<int64_t> deriveLinalgOpThreadTileSizes(linalg::LinalgOp linalgOp,
-                                                   int64_t numThreads) {
-  if (!linalgOp.hasPureTensorSemantics()) {
-    return {};
+  LDBG("Matmul TileAndFuse Config");
+
+  // Infer if lhs or rhs is transposed to help generate a better schedule.
+  // TODO: Drop this. This is only a consideration for other pipelines.
+  SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
+  bool transposedLhs =
+      kDim !=
+      llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
+  bool transposedRhs =
+      nDim !=
+      llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();
+
+  // First try to find a schedule with an exactly matching intrinsic.
+  std::optional<GPUMMASchedule> schedule =
+      deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
+                        targetSubgroupSize, transposedLhs, transposedRhs);
+  if (!schedule) {
+    // Then try again by allowing upcasting accumulator.
+    schedule = deduceMMASchedule(
+        problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
+        transposedLhs, transposedRhs, /*canUpcastAcc=*/true);
   }
-  // TODO: Support multi-result
-  if (linalgOp->getNumResults() != 1) {
-    return {};
+
+  if (!schedule) {
+    LDBG("Failed to deduce TileAndFuse MMA schedule");
+    return failure();
   }
-  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
-  int64_t vectorSize = kPreferredCopyNumBits /
-                       getElementTypeOrSelf(linalgOp->getResultTypes()[0])
-                           .getIntOrFloatBitWidth();
-  return getThreadTileSizesFromLoopRanges(loopRanges, numThreads, vectorSize);
-}
 
-SmallVector<int64_t>
-deriveIm2colOpThreadTileSizes(IREE::LinalgExt::Im2colOp im2colOp,
-                              int64_t numThreads) {
-  if (!im2colOp.hasPureTensorSemantics()) {
-    return {};
+  LDBG("Target Subgroup size: " << targetSubgroupSize);
+  LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
+                           << schedule->kSize << "]");
+  LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
+                                 << schedule->nTileCount << ", "
+                                 << schedule->kTileCount << "]");
+  LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
+                                 << schedule->nWarpCount << "]");
+
+  std::array<int64_t, 3> workgroupSize{
+      schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+
+  SmallVector<int64_t> workgroupTileSizes(linalgOp.getNumLoops(), 0);
+  SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
+  SmallVector<int64_t> subgroupTileSizes(linalgOp.getNumLoops(), 0);
+  // Tile all batch dimensions with unit size.
+  for (int64_t batch : contractionDims->batch) {
+    workgroupTileSizes[batch] = 1;
   }
+
+  // Tile all m, n, and k dimensions to 1 except the innermost. Unit dims
+  // from this tiling are folded before vectorization.
+  for (int64_t m : llvm::drop_end(contractionDims->m)) {
+    workgroupTileSizes[m] = 1;
+  }
+  for (int64_t n : llvm::drop_end(contractionDims->n)) {
+    workgroupTileSizes[n] = 1;
+  }
-  // TODO(Max191): Add `getStaticLoopRanges` to TilingInterface, and use it
-  // here instead of `im2colOp.getOutputType().getShape()`. Then we can also
-  // get rid of the specialization for Im2colOp vs LinalgOp and just use
-  // TilingInterface ops.
-  SmallVector<int64_t> loopRanges(im2colOp.getOutputType().getShape());
-  int64_t vectorSize = kPreferredCopyNumBits /
-                       getElementTypeOrSelf(im2colOp->getResultTypes()[0])
-                           .getIntOrFloatBitWidth();
-  return getThreadTileSizesFromLoopRanges(loopRanges, numThreads, vectorSize);
-}
-
-SmallVector<int64_t> deriveThreadTileSizes(Operation *op) {
-  std::optional<SmallVector<int64_t>> workgroupSize =
-      getWorkgroupSize(op->getParentOfType<FunctionOpInterface>());
-  if (!workgroupSize) {
-    return {};
+  for (int64_t k : llvm::drop_end(contractionDims->k)) {
+    reductionTileSizes[k] = 1;
   }
-  int64_t numThreads =
-      std::accumulate(workgroupSize->begin(), workgroupSize->end(), 1,
-                      std::multiplies<int64_t>());
-  return TypeSwitch<Operation *, SmallVector<int64_t>>(op)
-      .Case([&](linalg::LinalgOp linalgOp) -> SmallVector<int64_t> {
-        return deriveLinalgOpThreadTileSizes(linalgOp, numThreads);
-      })
-      .Case([&](IREE::LinalgExt::Im2colOp im2colOp) -> SmallVector<int64_t> {
-        return deriveIm2colOpThreadTileSizes(im2colOp, numThreads);
-      })
-      .Default([](Operation *op) -> SmallVector<int64_t> { return {}; });
+
+  // Compute the M/N dimension tile size by multiplying subgroup information.
+  workgroupTileSizes[mDim] =
+      schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
+  workgroupTileSizes[nDim] =
+      schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
+
+  // Specify the per-subgroup tile sizes from the mma schedule.
+  subgroupTileSizes[mDim] = schedule->mTileCount;
+  subgroupTileSizes[nDim] = schedule->nTileCount;
+
+  // Similarly the reduction tile size is just the post-packing tile count.
+  reductionTileSizes[kDim] = schedule->kTileCount;
+
+  IREE::GPU::MmaInterfaceAttr mmaKind = supportedMmas[schedule->index];
+
+  // Gather the tile sizes and the mma kind into a lowering config attribute
+  // for later access in the pipeline.
+  MLIRContext *context = linalgOp.getContext();
+  SmallVector<NamedAttribute> attrs;
+  Builder b(context);
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getIndexArrayAttr(workgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "reduction"),
+                     b.getIndexArrayAttr(reductionTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "subgroup"),
+                     b.getIndexArrayAttr(subgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "mma_kind"), mmaKind);
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
+
+  // TODO(qedawkins): Use a shared pipeline identifier here.
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPoint, op, loweringConfig,
+      IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse,
+      workgroupSize, targetSubgroupSize);
 }
 
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
index ed29a51ed607..913403193fe0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
@@ -7,11 +7,18 @@
 #ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_TARGETUTILS_CONFIGUTILS_H_
 #define IREE_COMPILER_CODEGEN_DIALECT_GPU_TARGETUTILS_CONFIGUTILS_H_
 
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir::iree_compiler::IREE::GPU {
 
-SmallVector<int64_t> deriveThreadTileSizes(Operation *op);
+/// Helper for setting up a matmul config based on the specified target.
+/// TODO: Currently this only succeeds if the target supports an mma
+/// kind. Add support for a fallback direct lowering path.
+LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
+                                      mlir::FunctionOpInterface entryPoint,
+                                      Operation *op);
 
 } // namespace mlir::iree_compiler::IREE::GPU
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 15ccb3c562ee..3db67b83dda1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1898,7 +1898,7 @@ setDefaultGenericOpRootConfig(mlir::FunctionOpInterface entryPointFn,
   unsigned numLoops = genericOp.getNumLoops();
   if (numLoops == 0) {
     return setOpConfigAndEntryPointFnTranslation(
-        entryPointFn, genericOp, {{}},
+        entryPointFn, genericOp, TileSizesListType{{}},
         DispatchLoweringPassPipeline::CPUDefault);
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index aeeee5e777c0..c57e1f585186 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -132,6 +132,7 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:ConfigUtils",
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
         "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
         "//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index f57af96e8081..30e722e39307 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -173,6 +173,7 @@ iree_cc_library(
    iree::compiler::Codegen::Common::VectorLayoutAnalysis
    iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+   iree::compiler::Codegen::Dialect::GPU::TargetUtils::ConfigUtils
    iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
    iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
    iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index b8c317fafeee..d7ee4e04a5f5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
 #include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
 #include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
@@ -45,6 +46,11 @@
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 namespace mlir::iree_compiler {
+llvm::cl::opt<bool> clGPUEnableTileAndFuse(
+    "iree-codegen-llvmgpu-use-tile-and-fuse",
+    llvm::cl::desc("enable the usage of the tile and fuse pipeline"),
+    llvm::cl::init(false));
+
 llvm::cl::opt<bool> clGPUEnableVectorDistribution(
     "iree-codegen-llvmgpu-use-vector-distribution",
     llvm::cl::desc("enable the usage of the vector distribution pipeline"),
@@ -1940,6 +1946,11 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
     LDBG("Transform Dialect Config");
     return success();
   }
+  if (clGPUEnableTileAndFuse && succeeded(IREE::GPU::setMatmulLoweringConfig(
+                                    target, entryPointFn, computeOp))) {
+    LDBG("Tile and fuse matmul config");
+    return success();
+  }
   if (succeeded(setVectorDistributionConfig(target, entryPointFn, computeOp))) {
     return success();
   }
@@ -2046,12 +2057,17 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
   }
 
   SmallVector<Operation *> computeOps = getComputeOps(funcOp);
-  if (getTranslationInfo(funcOp)) {
-    // Currently LLVMGPU requires propagation of user lowering configs.
-    for (auto op : computeOps) {
-      if (getLoweringConfig(op)) {
-        propagateLoweringConfig(op, computeOps);
-        break;
+  if (IREE::Codegen::TranslationInfoAttr translationInfo =
+          getTranslationInfo(funcOp)) {
+    // Currently LLVMGPU requires propagation of user lowering configs for
+    // all pipelines except TileAndFuse.
+    if (translationInfo.getDispatchLoweringPassPipeline() !=
+        IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
+      for (auto op : computeOps) {
+        if (getLoweringConfig(op)) {
+          propagateLoweringConfig(op, computeOps);
+          break;
+        }
       }
     }
     return success();
@@ -2097,6 +2113,16 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
   if (failed(setRootConfig(target, funcOp, rootOperation)))
     return funcOp.emitOpError("failed to set root config");
 
+  if (IREE::Codegen::TranslationInfoAttr translationInfo =
+          getTranslationInfo(funcOp)) {
+    // Currently LLVMGPU requires propagation of user lowering configs for
+    // all pipelines except TileAndFuse.
+    if (translationInfo.getDispatchLoweringPassPipeline() ==
+        IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
+      return success();
+    }
+  }
+
   propagateLoweringConfig(rootOperation, computeOps);
   return success();
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
index d05fe0ba01d2..2cd705dc844d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
@@ -8,6 +8,7 @@
 
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
 #include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
@@ -272,6 +273,10 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
                                    mlir::FunctionOpInterface entryPointFn,
                                    Operation *computeOp) {
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(computeOp)) {
+    if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
+                                                     linalgOp))) {
+      return success();
+    }
     if (succeeded(setWarpReductionConfig(target, entryPointFn, linalgOp))) {
       return success();
     }
@@ -326,12 +331,17 @@ LogicalResult initROCDLLaunchConfig(FunctionOpInterface funcOp) {
   }
 
   SmallVector<Operation *> computeOps = getComputeOps(funcOp);
-  if (getTranslationInfo(funcOp)) {
-    // Currently ROCDL requires propagation of user lowering configs.
-    for (auto op : computeOps) {
-      if (getLoweringConfig(op)) {
-        propagateLoweringConfig(op, computeOps);
-        break;
+  if (IREE::Codegen::TranslationInfoAttr translationInfo =
+          getTranslationInfo(funcOp)) {
+    // Currently ROCDL requires propagation of user lowering configs for
+    // all pipelines except TileAndFuse.
+    if (translationInfo.getDispatchLoweringPassPipeline() !=
+        IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
+      for (auto op : computeOps) {
+        if (getLoweringConfig(op)) {
+          propagateLoweringConfig(op, computeOps);
+          break;
+        }
       }
     }
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index c995397d8e45..978d1cf3a896 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -18,6 +18,7 @@ iree_lit_test_suite(
     name = "lit",
     srcs = enforce_glob(
         [
+            "config_tile_and_fuse.mlir",
             "config_vector_distribute.mlir",
             "config_user_vector_distribute.mlir",
             "lowering_scalar_dispatch.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index d80f1d605b2f..b4c1bf6e43b0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -14,6 +14,7 @@ iree_lit_test_suite(
   NAME
     lit
   SRCS
+    "config_tile_and_fuse.mlir"
     "config_user_vector_distribute.mlir"
     "config_vector_distribute.mlir"
     "lowering_scalar_dispatch.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
new file mode 100644
index 000000000000..2ac9f34f5209
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -0,0 +1,58 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 \
+// RUN:   --iree-codegen-llvmgpu-use-tile-and-fuse --iree-codegen-llvmgpu-use-vector-distribution=false \
+// RUN:   --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %5 = tensor.empty() : tensor<2x10x64x64xf16>
+  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+  %7 = linalg.generic {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
+    ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f16):
+    %8 = arith.mulf %in, %in_0 : f16
+    %9 = arith.addf %8, %out : f16
+    linalg.yield %9 : f16
+  } -> tensor<2x10x64x64xf16>
+  return %7 : tensor<2x10x64x64xf16>
+}
+
+// CHECK: #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64>
+// CHECK-LABEL: func.func @expanded_matmul_transpose_b
+
+// Verify that the fill does not have the lowering config propagated to it.
+// CHECK: linalg.fill ins
+
+// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:   mma_kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+// CHECK-SAME:   reduction = [0 : index, 0 : index, 0 : index, 0 : index, 8 : index]
+// CHECK-SAME:   subgroup = [0 : index, 0 : index, 4 : index, 1 : index, 0 : index]
+// CHECK-SAME:   workgroup = [1 : index, 1 : index, 64 : index, 64 : index, 0 : index]
+
+// -----
+
+func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<1024x1024xf16>) -> tensor<1024x1024xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %5 = tensor.empty() : tensor<1024x1024xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+  %7 = linalg.matmul ins(%lhs, %rhs : tensor<1024x1024xf16>, tensor<1024x1024xf16>) outs(%6 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+  return %7 : tensor<1024x1024xf32>
+}
+
+// CHECK: #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [128, 2, 1] subgroup_size = 64>
+// CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024
+
+// Verify that the fill does not have the lowering config propagated to it.
+// CHECK: linalg.fill ins
+
+// CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:   mma_kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+// CHECK-SAME:   reduction = [0 : index, 0 : index, 4 : index]
+// CHECK-SAME:   subgroup = [2 : index, 4 : index, 0 : index]
+// CHECK-SAME:   workgroup = [64 : index, 128 : index, 0 : index]
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 5a13996a9e84..b16dca1a666b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -1312,8 +1312,8 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target,
     // single thread to run everything.
     auto pipeline = CodeGenPipeline::SPIRVBaseDistribute;
     std::array<int64_t, 3> workgroupSize = {1, 1, 1};
-    return setOpConfigAndEntryPointFnTranslation(funcOp, op, {}, pipeline,
-                                                 workgroupSize);
+    return setOpConfigAndEntryPointFnTranslation(
+        funcOp, op, TileSizesListType{}, pipeline, workgroupSize);
   }
 
   int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
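
--
To try the new configuration path outside of the lit suite, a minimal
invocation mirrors the RUN line from config_tile_and_fuse.mlir above. The
flags are the ones added or used by this patch; `input.mlir` is a
hypothetical placeholder for any module containing a statically shaped
contraction:

  # Sketch only: run the lowering-strategy selection pass with the new
  # opt-in flag from KernelConfig.cpp on an MFMA-capable test target.
  iree-opt --split-input-file --iree-gpu-test-target=gfx940 \
    --iree-codegen-llvmgpu-use-tile-and-fuse \
    --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" \
    input.mlir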