Skip to content

Commit

Permalink
[GPU][NFC] Move gpu pipeline transformations out of iree_gpu dialect (i…
Browse files Browse the repository at this point in the history
…ree-org#19248)

This patch moves gpu pipeline specific passes out of iree_gpu dialect to
Codegen/Common/GPU. The idea is that passes that are specific to gpu
pipelines should live in Codegen/Common/GPU, while passes specific to
iree_gpu operations should live in Codegen/Dialect/GPU. This allows us
to make iree_gpu dialect not too bloated and depend on other dialects.
  • Loading branch information
Groverkss authored Nov 22, 2024
1 parent a467b73 commit 4a5187d
Show file tree
Hide file tree
Showing 15 changed files with 51 additions and 47 deletions.
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@ iree_compiler_cc_library(
"GPUDistributeScfFor.cpp",
"GPUDistributeSharedMemoryCopy.cpp",
"GPUDistributionPatterns.cpp",
"GPUFuseAndHoistParallelLoops.cpp",
"GPUGeneralizeNamedOps.cpp",
"GPUGreedilyDistributeToThreads.cpp",
"GPUInferMemorySpace.cpp",
"GPULowerToUKernels.cpp",
"GPUMaterializeEncoding.cpp",
"GPUMultiBuffering.cpp",
"GPUNestedLayoutDistributionPatterns.cpp",
"GPUPackToIntrinsics.cpp",
"GPUPatterns.cpp",
"GPUPipelining.cpp",
"GPUPromoteMatmulOperands.cpp",
Expand Down Expand Up @@ -130,6 +132,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:NVGPUDialect",
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ iree_cc_library(
"GPUDistributeScfFor.cpp"
"GPUDistributeSharedMemoryCopy.cpp"
"GPUDistributionPatterns.cpp"
"GPUFuseAndHoistParallelLoops.cpp"
"GPUGeneralizeNamedOps.cpp"
"GPUGreedilyDistributeToThreads.cpp"
"GPUInferMemorySpace.cpp"
"GPULowerToUKernels.cpp"
"GPUMaterializeEncoding.cpp"
"GPUMultiBuffering.cpp"
"GPUNestedLayoutDistributionPatterns.cpp"
"GPUPackToIntrinsics.cpp"
"GPUPatterns.cpp"
"GPUPipelining.cpp"
"GPUPromoteMatmulOperands.cpp"
Expand Down Expand Up @@ -105,6 +107,7 @@ iree_cc_library(
MLIRLinalgDialect
MLIRLinalgTransforms
MLIRLinalgUtils
MLIRLoopLikeInterface
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRNVGPUDialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,21 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-gpu-fuse-and-hoist-parallel-loops"
#define DEBUG_TYPE "iree-codegen-gpu-fuse-and-hoist-parallel-loops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir::iree_compiler::IREE::GPU {
namespace mlir::iree_compiler {

#define GEN_PASS_DEF_FUSEANDHOISTPARALLELLOOPSPASS
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
#define GEN_PASS_DEF_GPUFUSEANDHOISTPARALLELLOOPSPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

using namespace IREE::GPU;

namespace {
struct FuseAndHoistParallelLoopsPass final
: impl::FuseAndHoistParallelLoopsPassBase<FuseAndHoistParallelLoopsPass> {
struct GPUFuseAndHoistParallelLoopsPass final
: impl::GPUFuseAndHoistParallelLoopsPassBase<
GPUFuseAndHoistParallelLoopsPass> {
void runOnOperation() override;
};
} // namespace
Expand Down Expand Up @@ -322,7 +325,7 @@ struct FuseTilableForallConsumers final
}
};

void FuseAndHoistParallelLoopsPass::runOnOperation() {
void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
MLIRContext *context = &getContext();

FunctionOpInterface funcOp = getOperation();
Expand Down Expand Up @@ -390,4 +393,4 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
LDBG("After fusing new producers\n" << funcOp);
}

} // namespace mlir::iree_compiler::IREE::GPU
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,20 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::GPU {
namespace mlir::iree_compiler {

#define GEN_PASS_DEF_PACKTOINTRINSICSPASS
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
#define GEN_PASS_DEF_GPUPACKTOINTRINSICSPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

namespace {
struct PackToIntrinsicsPass final
: impl::PackToIntrinsicsPassBase<PackToIntrinsicsPass> {
struct GPUPackToIntrinsicsPass final
: impl::GPUPackToIntrinsicsPassBase<GPUPackToIntrinsicsPass> {
void runOnOperation() override;
};
} // namespace
Expand Down Expand Up @@ -90,7 +89,7 @@ struct ConvertToMultiMma final : OpInterfaceRewritePattern<linalg::LinalgOp> {
}
};

void PackToIntrinsicsPass::runOnOperation() {
void GPUPackToIntrinsicsPass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();

Expand Down Expand Up @@ -143,4 +142,4 @@ void PackToIntrinsicsPass::runOnOperation() {
}
}

} // namespace mlir::iree_compiler::IREE::GPU
} // namespace mlir::iree_compiler
19 changes: 19 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def GPUDistributeScfForPass :
];
}

def GPUFuseAndHoistParallelLoopsPass :
InterfacePass<"iree-codegen-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
let summary = "Greedily fuses and hoists parallel loops.";
let dependentDialects = [
"::mlir::affine::AffineDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
"::mlir::bufferization::BufferizationDialect"
];
}

def GPUGeneralizeNamedOpsPass :
InterfacePass<"iree-codegen-gpu-generalize-named-ops", "mlir::FunctionOpInterface"> {
let summary = "Convert named Linalg ops to linalg.generic ops";
Expand Down Expand Up @@ -115,6 +125,15 @@ def GPUMultiBufferingPass :
];
}

def GPUPackToIntrinsicsPass :
InterfacePass<"iree-codegen-gpu-pack-to-intrinsics", "mlir::FunctionOpInterface"> {
let summary = "Packs matmul like operations and converts to iree_gpu.multi_mma";
let dependentDialects = [
"::mlir::tensor::TensorDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
];
}

def GPUPipeliningPass :
InterfacePass<"iree-codegen-gpu-pipelining", "mlir::FunctionOpInterface"> {
let summary = "Pass to do software pipelining.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ iree_lit_test_suite(
"gpu_distribute_scf_for.mlir",
"gpu_distribute_shared_memory.mlir",
"gpu_generalize_named_ops.mlir",
"gpu_pack_to_instrinsics.mlir",
"gpu_fuse_and_hoist_forall.mlir",
"gpu_greedily_distribute_to_threads.mlir",
"gpu_infer_memory_space.mlir",
"gpu_lower_to_ukernels.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ iree_lit_test_suite(
"gpu_distribute_forall.mlir"
"gpu_distribute_scf_for.mlir"
"gpu_distribute_shared_memory.mlir"
"gpu_fuse_and_hoist_forall.mlir"
"gpu_generalize_named_ops.mlir"
"gpu_greedily_distribute_to_threads.mlir"
"gpu_infer_memory_space.mlir"
Expand All @@ -33,6 +34,7 @@ iree_lit_test_suite(
"gpu_nested_layout_contract_amdgpu.mlir"
"gpu_nested_layout_vector_distribution.mlir"
"gpu_nested_layout_vector_distribution_step.mlir"
"gpu_pack_to_instrinsics.mlir"
"gpu_pipeline.mlir"
"gpu_promote_matmul_operands.mlir"
"gpu_reorder_workgroups.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s
// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s

#translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt %s --mlir-print-local-scope --pass-pipeline='builtin.module(func.func(iree-gpu-pack-to-intrinsics, canonicalize, cse))' --split-input-file | FileCheck %s
// RUN: iree-opt %s --mlir-print-local-scope --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-pack-to-intrinsics, canonicalize, cse))' --split-input-file | FileCheck %s

#config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}>
module {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ iree_compiler_cc_library(
"CombineBarrierRegions.cpp",
"ConcretizeMmaShapes.cpp",
"DistributeMmaToLanes.cpp",
"FuseAndHoistParallelLoops.cpp",
"LowerIREEGPUOps.cpp",
"PackToIntrinsics.cpp",
"Passes.cpp",
"Transforms.cpp",
"UnrollToIntrinsics.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ iree_cc_library(
"CombineBarrierRegions.cpp"
"ConcretizeMmaShapes.cpp"
"DistributeMmaToLanes.cpp"
"FuseAndHoistParallelLoops.cpp"
"LowerIREEGPUOps.cpp"
"PackToIntrinsics.cpp"
"Passes.cpp"
"Transforms.cpp"
"UnrollToIntrinsics.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,6 @@ def DistributeMmaToLanesPass :
];
}

def FuseAndHoistParallelLoopsPass :
InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
let summary = "Greedily fuses and hoists parallel loops.";
let dependentDialects = [
"::mlir::affine::AffineDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
"::mlir::bufferization::BufferizationDialect"
];
}

def LowerIREEGPUOpsPass :
InterfacePass<"iree-gpu-lower-ops", "mlir::FunctionOpInterface"> {
let summary = "Post bufferization lowerings of iree_gpu ops before late lowerings";
Expand All @@ -62,15 +52,6 @@ def LowerIREEGPUOpsPass :
];
}

def PackToIntrinsicsPass :
InterfacePass<"iree-gpu-pack-to-intrinsics", "mlir::FunctionOpInterface"> {
let summary = "Packs matmul like operations and converts to iree_gpu.multi_mma";
let dependentDialects = [
"::mlir::tensor::TensorDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
];
}

def UnrollToIntrinsicsPass :
InterfacePass<"iree-gpu-unroll-to-intrinsics", "mlir::FunctionOpInterface"> {
let summary = "Unrolls iree_gpu.multi_mma ops to their inner vector size.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ iree_lit_test_suite(
"combine_barrier_regions.mlir",
"concretize_mma_shapes.mlir",
"distribute_mma_to_lanes.mlir",
"fuse_and_hoist_forall.mlir",
"pack_to_intrinsics.mlir",
],
include = ["*.mlir"],
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ iree_lit_test_suite(
"combine_barrier_regions.mlir"
"concretize_mma_shapes.mlir"
"distribute_mma_to_lanes.mlir"
"fuse_and_hoist_forall.mlir"
"pack_to_intrinsics.mlir"
TOOLS
FileCheck
iree-opt
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,

// Step 1. Promote matmul operands and pack to intrinsic shapes.
funcPassManager.addPass(createGPUPromoteMatmulOperandsPass());
funcPassManager.addPass(IREE::GPU::createPackToIntrinsicsPass());
funcPassManager.addPass(createGPUPackToIntrinsicsPass());
// Decompose packs and unpacks that are at the function boundary.
funcPassManager.addPass(createDecomposeBoundaryPackUnPackOpsPass());

Expand Down Expand Up @@ -421,7 +421,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
}

// Step 5. Greedily fuse parallel loops and hoist from serial loops.
funcPassManager.addPass(IREE::GPU::createFuseAndHoistParallelLoopsPass());
funcPassManager.addPass(createGPUFuseAndHoistParallelLoopsPass());
funcPassManager.addPass(createGPUGreedilyDistributeToThreadsPass());
funcPassManager.addPass(createTileLargeTensorsPass());
funcPassManager.addPass(createCanonicalizerPass());
Expand Down

0 comments on commit 4a5187d

Please sign in to comment.