From 87b920ea9b036a2a359b4c3b3ebda7caa45adad9 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Sat, 26 Aug 2023 01:41:28 -0400 Subject: [PATCH] [LLVMGPU] Extract subgroup size from export op to use for vector distribution (#14826) Previously subgroup size was hard coded to 32. This extracts the subgroup size from the `hal.executable.export` op associated with the target. --- .../TransformExtensions/LLVMGPUExtensions.cpp | 52 +++++++++---------- .../LLVMGPUExtensionsOps.td | 2 +- ...transform_dialect_vector_distribution.mlir | 2 +- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 83b38444f237..49dd178f5fbc 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" @@ -79,24 +80,13 @@ transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne( transform::TransformRewriter &rewriter, func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - if (!isa(state.getTopLevel())) { - state.getTopLevel()->emitOpError( - "requires HAL::ExecutableOp or HAL::ExecutableVariantOp " - "toplevel to " - "attach the workgroup size information to a nested " - "ExecutableExportOp"); - return emitDefaultDefiniteFailure(target); - } - - IREE::HAL::ExecutableExportOp exportOp; - state.getTopLevel()->walk([&](IREE::HAL::ExecutableExportOp op) { - if (op.getSymName() == target.getName()) - exportOp = op; - }); - if (!exportOp) { + FailureOr maybeExportOp = + getEntryPoint(target); + if (failed(maybeExportOp)) { state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found"); return emitDefaultDefiniteFailure(target); } + IREE::HAL::ExecutableExportOp exportOp = *maybeExportOp; auto transformOp = cast(getOperation()); @@ -574,11 +564,13 @@ static Value simpleWarpShuffleFunction(Location loc, OpBuilder &builder, static void populatePropagateVectorDistribution(Operation *target, RewritePatternSet &patterns, - PatternBenefit benefit) { - auto groupReductionFn = [](Location loc, OpBuilder &builder, Value input, - vector::CombiningKind kind, uint32_t size) { + PatternBenefit benefit, + unsigned subgroupSize) { + auto groupReductionFn = [subgroupSize]( + Location loc, OpBuilder &builder, Value input, + vector::CombiningKind kind, uint32_t size) { return mlir::iree_compiler::emitGPUGroupReduction(loc, builder, input, kind, - size, 32); + size, subgroupSize); }; assert(target->hasTrait()); vector::populatePropagateWarpVectorDistributionPatterns( @@ -604,14 +596,21 @@ static void populateWarpExecuteOnLane0ToScf( DiagnosedSilenceableFailure transform_dialect::VectorWarpDistributionOp::applyToOne( - transform::TransformRewriter &rewriter, Operation *target, + transform::TransformRewriter &rewriter, func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - if (!target->hasTrait()) { - target->emitOpError( - "applies only to isolated-from-above targets because it " - "needs to apply " - "patterns greedily"); + FailureOr maybeExportOp = + getEntryPoint(target); + if (failed(maybeExportOp)) { + state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found"); + return emitDefaultDefiniteFailure(target); + } + IREE::HAL::ExecutableExportOp exportOp = *maybeExportOp; + + std::optional subgroupSize = exportOp.getSubgroupSize(); + if (!subgroupSize) { + state.getTopLevel()->emitOpError( + "could not extract subgroup size from IREE::HAL::ExecutableExportOp"); return emitDefaultDefiniteFailure(target); } @@ -645,7 +644,8 @@ transform_dialect::VectorWarpDistributionOp::applyToOne( populateVectorTransferWriteDistribution(target, patterns, /*benefit=*/2); populatePropagateVectorDistribution(target, patterns, - /*benefit=*/1); + /*benefit=*/1, + subgroupSize->getSExtValue()); if (failed( applyPatternsAndFoldGreedily(target, std::move(patterns), config))) { return mlir::emitDefiniteFailure( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td index 9c2f6ee94c56..9c0d0b485ef9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td @@ -325,7 +325,7 @@ def VectorWarpDistributionOp : Op