From b06bf6a2213326fa84e73547dd65e6951b0e956c Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Fri, 9 Aug 2024 21:02:37 -0700 Subject: [PATCH] [Codegen] Query `#iree_gpu.target` for shared memory limit (#18184) Signed-off-by: nithinsubbiah --- .../Codegen/Common/GPU/GPUCheckResourceUsage.cpp | 14 +++++--------- .../src/iree/compiler/Codegen/Common/GPU/Passes.h | 2 -- .../src/iree/compiler/Codegen/LLVMGPU/Passes.cpp | 7 +------ .../src/iree/compiler/Codegen/SPIRV/Passes.cpp | 11 ++--------- 4 files changed, 8 insertions(+), 26 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp index 6e37bccc751b..05f4f2fd9ad4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp @@ -50,7 +50,7 @@ static int shapedTypeStaticSize( } /// Returns success if the total shared memory allocation size is less than the -/// limit set by limit. +/// limit. static LogicalResult checkGPUAllocationSize( mlir::FunctionOpInterface funcOp, unsigned limit, std::function getIndexBitwidth) { @@ -93,15 +93,14 @@ class GPUCheckResourceUsagePass final : public impl::GPUCheckResourceUsagePassBase { public: explicit GPUCheckResourceUsagePass( - std::function getSharedMemoryLimit, std::function getIndexBitwidth) - : getSharedMemoryLimit(getSharedMemoryLimit), - getIndexBitwidth(getIndexBitwidth) {} + : getIndexBitwidth(getIndexBitwidth) {} void runOnOperation() override { FunctionOpInterface funcOp = getOperation(); + IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); unsigned limit = - getSharedMemoryLimit ? getSharedMemoryLimit(funcOp) : 64 * 1024; + target ? target.getWgp().getMaxWorkgroupMemoryBytes() : 64 * 1024; if (failed(checkGPUAllocationSize(funcOp, limit, getIndexBitwidth ? getIndexBitwidth @@ -111,7 +110,6 @@ class GPUCheckResourceUsagePass final } private: - std::function getSharedMemoryLimit; std::function getIndexBitwidth; }; @@ -119,10 +117,8 @@ class GPUCheckResourceUsagePass final std::unique_ptr> createGPUCheckResourceUsagePass( - std::function getSharedMemoryLimit, std::function getIndexBitwidth) { - return std::make_unique(getSharedMemoryLimit, - getIndexBitwidth); + return std::make_unique(getIndexBitwidth); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h index 30bcc758dd27..aa5d22bef6c0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h @@ -87,8 +87,6 @@ LogicalResult gpuDistributeSharedMemoryCopy(mlir::FunctionOpInterface funcOp); // get the index size. std::unique_ptr> createGPUCheckResourceUsagePass( - std::function getSharedMemoryLimit = - nullptr, std::function getIndexBitwidth = nullptr); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index f116b97479dd..4517eb2b59a7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -992,13 +992,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, // Run checks on shared memory usage. funcPassManager .addPass([&]() { - auto getSharedMemoryLimit = [](mlir::FunctionOpInterface entryPoint) { - IREE::GPU::TargetAttr target = getGPUTargetAttr(entryPoint); - return target.getWgp().getMaxWorkgroupMemoryBytes(); - }; auto getIndexBitwidth = [](mlir::FunctionOpInterface) { return 64; }; - return createGPUCheckResourceUsagePass(getSharedMemoryLimit, - getIndexBitwidth); + return createGPUCheckResourceUsagePass(getIndexBitwidth); }) // SCF -> CF .addPass(createConvertSCFToCFPass) diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index d8b3d8a7e1c2..f8093dd2ddb4 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -183,18 +183,11 @@ static void addMemRefLoweringPasses(OpPassManager &modulePassManager) { .addPass(createPadDynamicAlloc); - // Check to make sure we are not exceeding shared memory usage limit. - auto getSharedMemoryLimit = [](mlir::FunctionOpInterface fn) { - IREE::GPU::TargetAttr target = getGPUTargetAttr(fn); - return target.getWgp().getMaxWorkgroupMemoryBytes(); - }; // TODO: query this from the target. auto getIndexBitwidth = [](mlir::FunctionOpInterface) { return 32; }; funcPassManager - .addPass([&]() { - return createGPUCheckResourceUsagePass(getSharedMemoryLimit, - getIndexBitwidth); - }) + .addPass( + [&]() { return createGPUCheckResourceUsagePass(getIndexBitwidth); }) // Fold load/store from/to subview ops into the original memref when // possible. In SPIR-V we don't use memref descriptor so it's not possible