[LLVMGPU] Extract subgroup size from export op to use for vector distribution (iree-org#14826)

Previously the subgroup size was hard-coded to 32. This change extracts the
subgroup size from the `hal.executable.export` op associated with the
target.
qedawkins authored Aug 26, 2023
1 parent 585e5ca commit 87b920e
Showing 3 changed files with 28 additions and 28 deletions.
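
In essence, the patch swaps a hard-coded warp width for a lookup on the export op. Below is a condensed sketch of the new control flow, stitched together from the hunks that follow; it is not a verbatim excerpt and omits the surrounding includes and class context:

  // Resolve the hal.executable.export op associated with the target function.
  FailureOr<IREE::HAL::ExecutableExportOp> maybeExportOp =
      getEntryPoint(target);
  if (failed(maybeExportOp)) {
    state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
    return emitDefaultDefiniteFailure(target);
  }

  // Read the optional subgroup_size attribute; fail if it is absent rather
  // than silently assuming a width of 32 as before.
  std::optional<llvm::APInt> subgroupSize = maybeExportOp->getSubgroupSize();
  if (!subgroupSize) {
    state.getTopLevel()->emitOpError(
        "could not extract subgroup size from IREE::HAL::ExecutableExportOp");
    return emitDefaultDefiniteFailure(target);
  }

  // The distribution patterns now receive the real subgroup width.
  populatePropagateVectorDistribution(target, patterns, /*benefit=*/1,
                                      subgroupSize->getSExtValue());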
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
@@ -79,24 +80,13 @@ transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne(
     transform::TransformRewriter &rewriter, func::FuncOp target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
-    state.getTopLevel()->emitOpError(
-        "requires HAL::ExecutableOp or HAL::ExecutableVariantOp "
-        "toplevel to "
-        "attach the workgroup size information to a nested "
-        "ExecutableExportOp");
-    return emitDefaultDefiniteFailure(target);
-  }
-
-  IREE::HAL::ExecutableExportOp exportOp;
-  state.getTopLevel()->walk([&](IREE::HAL::ExecutableExportOp op) {
-    if (op.getSymName() == target.getName())
-      exportOp = op;
-  });
-  if (!exportOp) {
+  FailureOr<IREE::HAL::ExecutableExportOp> maybeExportOp =
+      getEntryPoint(target);
+  if (failed(maybeExportOp)) {
     state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
     return emitDefaultDefiniteFailure(target);
   }
+  IREE::HAL::ExecutableExportOp exportOp = *maybeExportOp;
 
   auto transformOp = cast<transform::TransformOpInterface>(getOperation());
 
@@ -574,11 +564,13 @@ static Value simpleWarpShuffleFunction(Location loc, OpBuilder &builder,

 static void populatePropagateVectorDistribution(Operation *target,
                                                 RewritePatternSet &patterns,
-                                                PatternBenefit benefit) {
-  auto groupReductionFn = [](Location loc, OpBuilder &builder, Value input,
-                             vector::CombiningKind kind, uint32_t size) {
+                                                PatternBenefit benefit,
+                                                unsigned subgroupSize) {
+  auto groupReductionFn = [subgroupSize](
+                              Location loc, OpBuilder &builder, Value input,
+                              vector::CombiningKind kind, uint32_t size) {
     return mlir::iree_compiler::emitGPUGroupReduction(loc, builder, input, kind,
-                                                      size, 32);
+                                                      size, subgroupSize);
   };
   assert(target->hasTrait<OpTrait::IsIsolatedFromAbove>());
   vector::populatePropagateWarpVectorDistributionPatterns(
@@ -604,14 +596,21 @@ static void populateWarpExecuteOnLane0ToScf(

 DiagnosedSilenceableFailure
 transform_dialect::VectorWarpDistributionOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
+    transform::TransformRewriter &rewriter, func::FuncOp target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
-    target->emitOpError(
-        "applies only to isolated-from-above targets because it "
-        "needs to apply "
-        "patterns greedily");
+  FailureOr<IREE::HAL::ExecutableExportOp> maybeExportOp =
+      getEntryPoint(target);
+  if (failed(maybeExportOp)) {
+    state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
+    return emitDefaultDefiniteFailure(target);
+  }
+  IREE::HAL::ExecutableExportOp exportOp = *maybeExportOp;
+
+  std::optional<llvm::APInt> subgroupSize = exportOp.getSubgroupSize();
+  if (!subgroupSize) {
+    state.getTopLevel()->emitOpError(
+        "could not extract subgroup size from IREE::HAL::ExecutableExportOp");
     return emitDefaultDefiniteFailure(target);
   }
 
@@ -645,7 +644,8 @@ transform_dialect::VectorWarpDistributionOp::applyToOne(
   populateVectorTransferWriteDistribution(target, patterns,
                                           /*benefit=*/2);
   populatePropagateVectorDistribution(target, patterns,
-                                      /*benefit=*/1);
+                                      /*benefit=*/1,
+                                      subgroupSize->getSExtValue());
   if (failed(
           applyPatternsAndFoldGreedily(target, std::move(patterns), config))) {
     return mlir::emitDefiniteFailure(
@@ -325,7 +325,7 @@ def VectorWarpDistributionOp : Op<Transform_Dialect, "iree.vector.warp_distribut
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
+        ::mlir::func::FuncOp target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
@@ -11,7 +11,7 @@

 hal.executable private @reduce_dispatch_0 {
   hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
-    hal.executable.export public @reduce_dispatch_0 ordinal(0) layout(#pipeline_layout) attributes { workgroup_size = [64: index, 1: index, 1: index] }
+    hal.executable.export public @reduce_dispatch_0 ordinal(0) layout(#pipeline_layout) attributes { workgroup_size = [64: index, 1: index, 1: index], subgroup_size = 32 : index }
     builtin.module {
       func.func @reduce_dispatch_0() {
         %c0 = arith.constant 0 : index