Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen] Sprinkle in PropagateDispatchSizeBounds passes #19677

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ static std::pair<Value, Value> makeTransposedIds(Location loc, OpBuilder b,
/// Returns the workgroup counts along the X and Y dimensions. These will be
/// constants when static in the corresponding `hal.executable.export` op.
static std::pair<Value, Value>
getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp) {
getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp,
std::optional<APInt> xBound, std::optional<APInt> yBound) {
Location loc = funcOp.getLoc();
SmallVector<int64_t> workgroupCounts = getStaticNumWorkgroups(funcOp);
bool isStaticWgCount = llvm::none_of(workgroupCounts, ShapedType::isDynamic);
Expand All @@ -62,9 +63,9 @@ getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp) {

LLVM_DEBUG(llvm::dbgs() << "Using dynamic workgroup counts\n");
Value dynamicCountX =
builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0);
builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0, xBound);
Value dynamicCountY =
builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1);
builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1, yBound);
return {dynamicCountX, dynamicCountY};
}

Expand Down Expand Up @@ -100,11 +101,12 @@ reorderWorkgroupsInFunc(FunctionOpInterface funcOp,
// that to RAUW the old ones. This way we don't have to worry about the
// picking the exact insertion points that do not violate dominance between
// their defs and users.
Value workgroupIdX =
builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(funcOp.getLoc(), 0);
Value workgroupIdY =
builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(funcOp.getLoc(), 1);
auto [workgroupCntX, workgroupCntY] = getWorkgroupCountsXY(builder, funcOp);
Value workgroupIdX = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
funcOp.getLoc(), 0, oldXId.getUpperBound());
Value workgroupIdY = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
funcOp.getLoc(), 1, oldYId.getUpperBound());
auto [workgroupCntX, workgroupCntY] = getWorkgroupCountsXY(
builder, funcOp, oldXId.getUpperBound(), oldYId.getUpperBound());
Value newWorkgroupIdX;
Value newWorkgroupIdY;
assert(strategy == ReorderWorkgroupsStrategy::Transpose &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ void registerCodegenInterfaces(DialectRegistry &registry) {
affine::registerValueBoundsOpInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::registerTransformDialectExtension(registry);
gpu::registerValueBoundsOpInterfaceExternalModels(registry);
gpu::registerTransformDialectExtension(registry);
gpu::registerValueBoundsOpInterfaceExternalModels(registry);
linalg::registerTransformDialectExtension(registry);
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
funcPassManager.addPass(createConcretizePadResultShapePass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
}

//===---------------------------------------------------------------------===//
Expand Down Expand Up @@ -447,6 +448,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
addCPUBufferizePasses(funcPassManager);

// Run IREE specific passes before vector lowering expert.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());

{
Expand Down Expand Up @@ -510,6 +512,7 @@ void addConvTileAndDecomposeExpertPassPipeline(
addCPUBufferizePasses(funcPassManager);

// Run IREE specific passes before vector lowering expert.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());

{
Expand Down
17 changes: 14 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createGPUDistributePass());

// Post bufferization optimizations.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
funcPassManager.addPass(createCanonicalizerPass());
Expand Down Expand Up @@ -439,6 +440,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createTileLargeTensorsPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
funcPassManager.addPass(IREE::GPU::createCombineBarrierRegionsPass());

Expand Down Expand Up @@ -468,6 +470,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCSEPass());

// Step 9. Remaining post-bufferization optimizations/lowerings.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
funcPassManager.addPass(createUnrollAnnotatedLoopsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
Expand Down Expand Up @@ -524,6 +527,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createGPUDistributeScfForPass(options));

// Post bufferization optimizations.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
Expand All @@ -544,6 +548,7 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
// Distribute linalg onto warps within the workgroup.
funcPassManager.addPass(
createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/true));
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
if (pipelineDepth > 1) {
funcPassManager.addPass(createGPUMultiBufferingPass(
Expand Down Expand Up @@ -589,6 +594,7 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCSEPass());

// Hoist loop invariant code to avoid pipelining it.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
// Pipeline memory operations.
GPUPipeliningPassOptions pipelieningOptions = {};
Expand All @@ -613,6 +619,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
// Distribute linalg onto warps within the workgroup.
funcPassManager.addPass(
createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/true));
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
if (pipelineDepth > 1) {
funcPassManager.addPass(createGPUMultiBufferingPass(
Expand Down Expand Up @@ -655,6 +662,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
funcPassManager.addPass(createCSEPass());

// Hoist loop invariant code to avoid pipelining it.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
// Pipeline memory operations.
GPUPipeliningPassOptions pipelieningOptions = {};
Expand Down Expand Up @@ -882,6 +890,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createGPUTileReductionPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());

// Linalg -> vector
{
Expand Down Expand Up @@ -949,6 +958,7 @@ void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
}

Expand All @@ -965,6 +975,7 @@ void addGPUDefaultPassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCSEPass());

addBufferizePasses(funcPassManager);
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
}

Expand All @@ -981,6 +992,7 @@ void addGPUBaseLoweringPassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(IREE::LinalgExt::createLinalgExtToLoopsPass());
funcPassManager.addPass(createMemrefCopyToLinalgPass());
funcPassManager.addPass(createConvertLinalgToLoopsPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
Expand All @@ -999,6 +1011,7 @@ addLowerAndOptimizeAddressComputationPasses(FunctionLikeNest &funcPassManager) {
.addPass(memref::createExpandOpsPass)
.addPass(memref::createFoldMemRefAliasOpsPass)
.addPass(memref::createExpandStridedMetadataPass)
.addPass(createPropagateDispatchSizeBoundsPass)
// Hoist loop invariant variables to give affine decomposition pass the
// right loop dependencies.
.addPass(createIREELoopInvariantCodeMotionPass)
Expand Down Expand Up @@ -1055,9 +1068,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
FunctionLikeNest funcPassManager(modulePassManager);
funcPassManager.addPass(createFoldTensorExtractOpPass)
.addPass(createLLVMGPUVectorLoweringPass)
.addPass(createExpandGPUOpsPass)
// Expose workitem and workgroup counts to range inference later.
.addPass(createPropagateDispatchSizeBoundsPass);
.addPass(createExpandGPUOpsPass);

// This pass needs to run before SCF -> CF.
addLowerAndOptimizeAddressComputationPasses(funcPassManager);
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ static void addLoopMaterializationPasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(IREE::LinalgExt::createLinalgExtToLoopsPass());
funcPassManager.addPass(createMemrefCopyToLinalgPass());
funcPassManager.addPass(createConvertLinalgToLoopsPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
Expand Down Expand Up @@ -394,6 +395,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(
funcPassManager.addPass(
createSPIRVTileAndPromotePass(SPIRVTileAndPromotePassOptions{
/*promoteCMatrix=*/true, /*skipThreadLevel=*/true}));
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
// Run canonicalization patterns to propagate constant shape sizes after
// removing trip-one loops.
Expand Down Expand Up @@ -421,6 +423,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(
funcPassManager.addPass(createGPUReduceBankConflictsPass(options));
}

funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
// Performs high-level n-D mechanical vectorization. This does not perform
// unrolling or lowering, which is done later.
{
Expand Down Expand Up @@ -513,6 +516,7 @@ void addSPIRVMatmulPromoteVectorizePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createGPUDistributeSharedMemoryCopyPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());

{
GPUReduceBankConflictsPassOptions options = {};
Expand All @@ -532,6 +536,7 @@ void addSPIRVMatmulPromoteVectorizePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createForOpCanonicalizationPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createOptimizeVectorTransferPass());

// Hoist loop invariant code to avoid pipelining it.
Expand Down Expand Up @@ -560,6 +565,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createGPUTileReductionPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());

// Performs high-level n-D mechanical vectorization. This does not perform
// unrolling or lowering, which is done later.
Expand Down Expand Up @@ -588,6 +594,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &funcPassManager) {

// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createOptimizeVectorTransferPass());

// Simplify the IR for vector distribution.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ func.func @warp_reduction_dispatch() attributes {hal.executable.target = #execut

// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f16

// CHECK-DAG: %[[WGIDX:.+]] = hal.interface.workgroup.id[0] : index
// CHECK-DAG: %[[WGIDY:.+]] = hal.interface.workgroup.id[1] : index
// CHECK-DAG: %[[WGIDX:.+]] = hal.interface.workgroup.id[0] upper_bound 65535 : index
// CHECK-DAG: %[[WGIDY:.+]] = hal.interface.workgroup.id[1] upper_bound 65535 : index
// CHECK-DAG: %[[TIDX:.+]] = gpu.thread_id x

// CHECK-DAG: %[[SPAN0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ void addVMVXDefaultPassPipeline(OpPassManager &funcPassManager,
addCPUBufferizePasses(funcPassManager);

// Cleanup the IR that may now have unused loops.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());

// Convert buffer-level microkernels.
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3039,9 +3039,10 @@ class HAL_InterfaceWorkgroupOp<string mnemonic, list<Trait> traits = []>
let results = (outs HAL_Dim:$result);

let builders = [
OpBuilder<(ins "unsigned":$dim),
OpBuilder<(ins "unsigned":$dim, CArg<"std::optional<::llvm::APInt>", "std::nullopt">:$upper_bound),
[{
build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim), ::mlir::IntegerAttr{});
build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim),
upper_bound.has_value() ? $_builder.getIndexAttr(upper_bound->getSExtValue()) : ::mlir::IntegerAttr{});
}]>,
];

Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ iree_compiler_cc_library(
name = "ExternalModels",
srcs = [
"FlowExternalModels.cpp",
"HALExternalModels.cpp",
"Interfaces.cpp",
"LinalgExtExternalModels.cpp",
"StreamExternalModels.cpp",
"UtilExternalModels.cpp",
],
hdrs = [
"FlowExternalModels.h",
"HALExternalModels.h",
"Interfaces.h",
"LinalgExtExternalModels.h",
"StreamExternalModels.h",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ iree_cc_library(
ExternalModels
HDRS
"FlowExternalModels.h"
"HALExternalModels.h"
"Interfaces.h"
"LinalgExtExternalModels.h"
"StreamExternalModels.h"
"UtilExternalModels.h"
SRCS
"FlowExternalModels.cpp"
"HALExternalModels.cpp"
"Interfaces.cpp"
"LinalgExtExternalModels.cpp"
"StreamExternalModels.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/ExternalInterfaces/HALExternalModels.h"

#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

namespace mlir::iree_compiler {

namespace {

//===----------------------------------------------------------------------===//
// ValueBoundsOpInterface
//===----------------------------------------------------------------------===//

template <typename IDOp>
struct IDOpValueBoundsInterface : public ValueBoundsOpInterface::ExternalModel<
IDOpValueBoundsInterface<IDOp>, IDOp> {
void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
auto boundOp = cast<IDOp>(op);
assert(value == boundOp.getResult() && "value must be op result");
cstr.bound(value) >= 0;
if (boundOp.getUpperBound()) {
cstr.bound(value) < boundOp.getUpperBound()->getSExtValue();
}
}
};

template <typename CountOp>
struct CountOpValueBoundsInterface
: public ValueBoundsOpInterface::ExternalModel<
CountOpValueBoundsInterface<CountOp>, CountOp> {
void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
auto boundOp = cast<CountOp>(op);
assert(value == boundOp.getResult() && "value must be op result");
cstr.bound(value) >= 1;
if (boundOp.getUpperBound()) {
cstr.bound(value) <= boundOp.getUpperBound()->getSExtValue();
}
}
};

} // namespace

void registerHALExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context,
IREE::HAL::HALDialect *dialect) {
IREE::HAL::InterfaceWorkgroupIDOp::attachInterface<
IDOpValueBoundsInterface<IREE::HAL::InterfaceWorkgroupIDOp>>(*context);

IREE::HAL::InterfaceWorkgroupSizeOp::attachInterface<
CountOpValueBoundsInterface<IREE::HAL::InterfaceWorkgroupSizeOp>>(
*context);
IREE::HAL::InterfaceWorkgroupCountOp::attachInterface<
CountOpValueBoundsInterface<IREE::HAL::InterfaceWorkgroupCountOp>>(
*context);
});
}
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_EXTERNALINTERFACES_HALEXTERNALMODELS_H_
#define IREE_COMPILER_EXTERNALINTERFACES_HALEXTERNALMODELS_H_

namespace mlir {
class DialectRegistry;
} // namespace mlir

namespace mlir::iree_compiler {

void registerHALExternalModels(DialectRegistry &registry);

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_EXTERNALINTERFACES_HALEXTERNALMODELS_H_
Loading
Loading