-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GPU launch constants inlining (#888)
Adds a pass that inlines constants into gpu.launch body. This reduces the number of GPU kernel parameters after kernel outlining and allows for further constant propagation within the kernel. --------- Co-authored-by: Renato Golin <[email protected]>
- Loading branch information
Showing
6 changed files
with
192 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
//===- GpuInlineConstants.cpp ------------------------------------*- C++-*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "TPP/Passes.h" | ||
|
||
#include "mlir/Conversion/Passes.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/GPU/IR/GPUDialect.h" | ||
#include "mlir/IR/Dialect.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "mlir/Transforms/Passes.h" | ||
#include "mlir/Transforms/RegionUtils.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::tpp; | ||
|
||
namespace mlir { | ||
namespace tpp { | ||
#define GEN_PASS_DEF_GPUINLINECONSTANTS | ||
#include "TPP/Passes.h.inc" | ||
} // namespace tpp | ||
} // namespace mlir | ||
|
||
namespace { | ||
|
||
// Inlines constants into GPU launch body. | ||
struct InlineConstantsIntoGPULaunch : public OpRewritePattern<gpu::LaunchOp> { | ||
using OpRewritePattern<gpu::LaunchOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(gpu::LaunchOp launchOp, | ||
PatternRewriter &rewriter) const override { | ||
Region &launchOpBody = launchOp.getBody(); | ||
|
||
// Identify values defined outside of the launch operation. | ||
SetVector<Value> aboveVals; | ||
getUsedValuesDefinedAbove(launchOpBody, aboveVals); | ||
|
||
// Gather operations representing constants. | ||
SetVector<Operation *> constantOps; | ||
for (auto val : aboveVals) { | ||
auto *op = val.getDefiningOp(); | ||
// TODO: Add more constant representations. | ||
if (op && isa<arith::ConstantOp>(op)) | ||
constantOps.insert(op); | ||
} | ||
|
||
// Clone the constants into the gpu.launch body. | ||
OpBuilder::InsertionGuard guard(rewriter); | ||
rewriter.setInsertionPointToStart(&launchOpBody.front()); | ||
|
||
for (auto *op : constantOps) { | ||
auto *clonedOp = rewriter.clone(*op); | ||
|
||
// Replace uses within the body with the inlined values. | ||
for (auto [oldVal, newVal] : | ||
llvm::zip_equal(op->getResults(), clonedOp->getResults())) { | ||
replaceAllUsesInRegionWith(oldVal, newVal, launchOpBody); | ||
} | ||
} | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
void populateGpuInlineConstantsPatterns(RewritePatternSet &patterns) { | ||
patterns.add<InlineConstantsIntoGPULaunch>(patterns.getContext()); | ||
} | ||
|
||
struct GpuInlineConstants | ||
: public tpp::impl::GpuInlineConstantsBase<GpuInlineConstants> { | ||
using GpuInlineConstantsBase::GpuInlineConstantsBase; | ||
|
||
void runOnOperation() override { | ||
RewritePatternSet patterns(&getContext()); | ||
populateGpuInlineConstantsPatterns(patterns); | ||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// RUN: tpp-opt %s -gpu-inline-constants -split-input-file | FileCheck %s | ||
|
||
// RUN: tpp-opt %s -gpu-inline-constants -gpu-kernel-outlining -canonicalize -cse -split-input-file | \ | ||
// RUN: FileCheck %s --check-prefix=OUTLINED | ||
|
||
func.func @scalar_constants(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) { | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c8 = arith.constant 8 : index | ||
%c16 = arith.constant 16 : index | ||
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { | ||
scf.for %i = %c0 to %c8 step %c1 { | ||
scf.for %j = %c0 to %c16 step %c1 { | ||
%0 = memref.load %arg0[%i, %j] : memref<8x16xf16> | ||
memref.store %0, %arg1[%i, %j] : memref<8x16xf16> | ||
} | ||
} | ||
gpu.terminator | ||
} | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @scalar_constants | ||
// CHECK: arith.constant 1 : index | ||
// CHECK: gpu.launch | ||
// CHECK-DAG: arith.constant 0 : index | ||
// CHECK-DAG: arith.constant 1 : index | ||
// CHECK-DAG: arith.constant 8 : index | ||
// CHECK-DAG: arith.constant 16 : index | ||
|
||
// OUTLINED-LABEL: func.func @scalar_constants | ||
// OUTLINED-SAME: %[[arg0:.+]]: memref<8x16xf16>, %[[arg1:.+]]: memref<8x16xf16> | ||
// OUTLINED: arith.constant 1 : index | ||
// OUTLINED: gpu.launch_func{{.*}}args(%[[arg0]] : memref<8x16xf16>, %[[arg1]] : memref<8x16xf16>) | ||
// OUTLINED: gpu.module | ||
// OUTLINED-LABEL: gpu.func @scalar_constants_kernel | ||
// OUTLINED-DAG: arith.constant 0 : index | ||
// OUTLINED-DAG: arith.constant 1 : index | ||
// OUTLINED-DAG: arith.constant 8 : index | ||
// OUTLINED-DAG: arith.constant 16 : index | ||
|
||
// ----- | ||
|
||
func.func @dense_constant(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) { | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%v0 = arith.constant dense<0.0> : vector<8x16xf16> | ||
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { | ||
%0 = vector.load %arg0[%c0, %c0] : memref<8x16xf16>, vector<8x16xf16> | ||
%1 = arith.maximumf %0, %v0 : vector<8x16xf16> | ||
vector.store %1, %arg1[%c0, %c0] : memref<8x16xf16>, vector<8x16xf16> | ||
gpu.terminator | ||
} | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @dense_constant | ||
// CHECK: arith.constant 1 : index | ||
// CHECK: gpu.launch | ||
// CHECK-DAG: arith.constant 0 : index | ||
// CHECK-DAG: arith.constant dense<0.000000e+00> : vector<8x16xf16> | ||
|
||
// OUTLINED-LABEL: func.func @dense_constant | ||
// OUTLINED-SAME: %[[arg0:.+]]: memref<8x16xf16>, %[[arg1:.+]]: memref<8x16xf16> | ||
// OUTLINED: arith.constant 1 : index | ||
// OUTLINED: gpu.launch_func{{.*}}args(%[[arg0]] : memref<8x16xf16>, %[[arg1]] : memref<8x16xf16>) | ||
// OUTLINED: gpu.module | ||
// OUTLINED-LABEL: gpu.func @dense_constant_kernel | ||
// OUTLINED-DAG: arith.constant 0 : index | ||
// OUTLINED-DAG: arith.constant dense<0.000000e+00> : vector<8x16xf16> |