Skip to content

Commit

Permalink
GPU launch constants inlining (#888)
Browse files Browse the repository at this point in the history
Adds a pass that inlines constants into the gpu.launch body.

This reduces the number of GPU kernel parameters after kernel outlining
and allows for further constant propagation within the kernel.

---------

Co-authored-by: Renato Golin <[email protected]>
  • Loading branch information
2 people authored and nhasabni committed Mar 14, 2024
1 parent 2a09756 commit 1c6c4ba
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 13 deletions.
11 changes: 11 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -483,4 +483,15 @@ def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> {
let dependentDialects = ["affine::AffineDialect", "scf::SCFDialect"];
}

// Declares the `gpu-inline-constants` pass, which is anchored on
// func::FuncOp. It is meant to run right before GPU kernel outlining so that
// constants used by a gpu.launch body do not become kernel parameters.
def GpuInlineConstants : Pass<"gpu-inline-constants", "func::FuncOp"> {
let summary = "Inlines constants into GPU launch.";
let description = [{
Inline constants into GPU launch body to reduce number of parameters
and allow further constant propagation after kernel outlining.
The pass should be used just before GPU kernel outlining.
}];
let dependentDialects = ["gpu::GPUDialect",
"arith::ArithDialect"];
}

#endif // TPP_DIALECT_TPP_PASSES
1 change: 1 addition & 0 deletions lib/TPP/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_mlir_library(TPPGPU
GpuVulkanAbi.cpp
LinalgToGpu.cpp
GpuDataTransfer.cpp
GpuInlineConstants.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/GPU/GpuConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ struct GpuConversion : public tpp::impl::GpuConversionBase<GpuConversion>,
pm.addNestedPass<func::FuncOp>(createCleanup());

// Create GPU kernels.
pm.addNestedPass<func::FuncOp>(createGpuInlineConstants());
pm.addPass(createGpuKernelOutliningPass());

// Generic cleanup.
Expand Down
88 changes: 88 additions & 0 deletions lib/TPP/GPU/GpuInlineConstants.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//===- GpuInlineConstants.cpp ------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"

using namespace mlir;
using namespace mlir::tpp;

// Pull in the TableGen-generated definitions for this pass. Defining
// GEN_PASS_DEF_GPUINLINECONSTANTS before including Passes.h.inc selects the
// section that provides tpp::impl::GpuInlineConstantsBase, the base class
// used by the pass implementation in this file.
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_GPUINLINECONSTANTS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

// Inlines constants into GPU launch body.
//
// Every `arith.constant` that is defined outside of a `gpu.launch` region but
// used inside of it is cloned to the start of the launch body, and the uses
// within the body are redirected to the clone. The original constant is left
// in place, as it may still have users outside of the launch.
//
// The pattern fails to match when the launch body uses no outside constants,
// so that the rewrite driver can reach a fixed point.
struct InlineConstantsIntoGPULaunch : public OpRewritePattern<gpu::LaunchOp> {
  using OpRewritePattern<gpu::LaunchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::LaunchOp launchOp,
                                PatternRewriter &rewriter) const override {
    Region &launchOpBody = launchOp.getBody();

    // Identify values defined outside of the launch operation.
    SetVector<Value> aboveVals;
    getUsedValuesDefinedAbove(launchOpBody, aboveVals);

    // Gather operations representing constants.
    // Block arguments have no defining op and are skipped implicitly.
    SetVector<Operation *> constantOps;
    for (Value val : aboveVals) {
      // TODO: Add more constant representations.
      if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
        constantOps.insert(constantOp.getOperation());
    }

    // Nothing to inline: report a match failure instead of success. A pattern
    // must only return success when it changed the IR, otherwise the greedy
    // driver keeps reapplying it and never converges.
    if (constantOps.empty())
      return rewriter.notifyMatchFailure(
          launchOp, "no constants to inline into the launch body");

    // Clone the constants into the gpu.launch body.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&launchOpBody.front());

    for (Operation *op : constantOps) {
      Operation *clonedOp = rewriter.clone(*op);

      // Replace uses within the body with the inlined values.
      // Uses outside of the launch body keep referring to the original op.
      for (auto [oldVal, newVal] :
           llvm::zip_equal(op->getResults(), clonedOp->getResults())) {
        replaceAllUsesInRegionWith(oldVal, newVal, launchOpBody);
      }
    }

    return success();
  }
};

// Registers the rewrite patterns that inline constants into gpu.launch
// bodies into the given pattern set.
void populateGpuInlineConstantsPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<InlineConstantsIntoGPULaunch>(ctx);
}

// Pass driver: applies the constant inlining patterns greedily on the
// operation the pass is anchored on.
struct GpuInlineConstants
    : public tpp::impl::GpuInlineConstantsBase<GpuInlineConstants> {
  using GpuInlineConstantsBase::GpuInlineConstantsBase;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    RewritePatternSet inliningPatterns(ctx);
    populateGpuInlineConstantsPatterns(inliningPatterns);

    // The convergence status of the driver is deliberately discarded; the
    // pass never signals failure.
    FrozenRewritePatternSet frozenPatterns(std::move(inliningPatterns));
    (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
  }
};

} // namespace
34 changes: 21 additions & 13 deletions test/GPU/gpu-conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ func.func @identity_with_bcast(%arg0: memref<5x1xf32>, %arg1: memref<5x6xf32>) {
// CHECK: gpu.launch_func @identity_with_bcast_kernel::@identity_with_bcast_kernel
// CHECK: gpu.module @identity_with_bcast_kernel
// CHECK-LABEL: gpu.func @identity_with_bcast_kernel
// CHECK-SAME: %[[ARG0:.+]]: memref<5x1xf32>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: memref<5x6xf32>
// CHECK-SAME: %[[ARG0:.+]]: memref<5x1xf32>, %[[ARG2:.+]]: memref<5x6xf32>
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[X:.+]] = gpu.block_id x
// CHECK-NEXT: %[[Y:.+]] = gpu.block_id y
// CHECK: %[[L:.+]] = memref.load %arg0[%[[X]], %[[ARG1]]] : memref<5x1xf32>
// CHECK: %[[L:.+]] = memref.load %arg0[%[[X]], %[[c0]]] : memref<5x1xf32>
// CHECK: memref.store %[[L]], %[[ARG2]][%[[X]], %[[Y]]] : memref<5x6xf32>
// CHECK: gpu.return

Expand All @@ -160,11 +161,12 @@ func.func @relu(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) {
// CHECK: gpu.launch_func @relu_kernel::@relu_kernel
// CHECK: gpu.module @relu_kernel
// CHECK-LABEL: gpu.func @relu_kernel
// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32>, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: memref<3x3xf32>
// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32>, %[[ARG2:.+]]: memref<3x3xf32>
// CHECK-DAG: %[[c0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[X:.+]] = gpu.block_id x
// CHECK-NEXT: %[[Y:.+]] = gpu.block_id y
// CHECK: %[[L:.+]] = memref.load %[[ARG0]][%[[X]], %[[Y]]] : memref<3x3xf32>
// CHECK: %[[M:.+]] = arith.maximumf %[[L]], %[[ARG1]] : f32
// CHECK: %[[M:.+]] = arith.maximumf %[[L]], %[[c0]] : f32
// CHECK: memref.store %[[M]], %[[ARG2]][%[[X]], %[[Y]]] : memref<3x3xf32>
// CHECK: gpu.return

Expand All @@ -181,10 +183,11 @@ func.func @zero(%arg0: memref<3x3xf32>) {
// CHECK: gpu.launch_func @zero_kernel::@zero_kernel
// CHECK: gpu.module @zero_kernel
// CHECK-LABEL: gpu.func @zero_kernel
// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: memref<3x3xf32>
// CHECK-SAME: %[[ARG1:.+]]: memref<3x3xf32>
// CHECK-DAG: %[[c0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[X:.+]] = gpu.block_id x
// CHECK-NEXT: %[[Y:.+]] = gpu.block_id y
// CHECK: memref.store %[[ARG0]], %[[ARG1]][%[[X]], %[[Y]]] : memref<3x3xf32>
// CHECK: memref.store %[[c0]], %[[ARG1]][%[[X]], %[[Y]]] : memref<3x3xf32>
// CHECK: gpu.return

// -----
Expand Down Expand Up @@ -221,14 +224,17 @@ func.func @brgemm(%arg0: memref<2x3x4xf32>, %arg1: memref<2x4x3xf32>, %arg2: mem
// CHECK: gpu.launch_func @brgemm_kernel::@brgemm_kernel
// CHECK: gpu.module @brgemm_kernel
// CHECK-LABEL: gpu.func @brgemm_kernel
// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32>, %[[ARG1:.+]]: memref<2x3x4xf32>, %[[ARG2:.+]]: memref<2x4x3xf32>,
// CHECK-SAME: %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index, %[[ARG6:.+]]: index
// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32>, %[[ARG1:.+]]: memref<2x3x4xf32>, %[[ARG2:.+]]: memref<2x4x3xf32>
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
// CHECK: %[[X:.+]] = gpu.block_id x
// CHECK-NEXT: %[[Y:.+]] = gpu.block_id y
// CHECK: %[[C:.+]] = memref.load %[[ARG0]][%[[X]], %[[Y]]] : memref<3x3xf32>
// CHECK: %[[R:.+]] = scf.for %[[ARG7:.+]] = %[[ARG3]] to %[[ARG6]] step %[[ARG5]]
// CHECK: %[[R:.+]] = scf.for %[[ARG7:.+]] = %[[c0]] to %[[c2]] step %[[c1]]
// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[C]])
// CHECK: %{{.+}} = scf.for %[[ARG9:.+]] = %[[ARG3]] to %[[ARG4]] step %[[ARG5]]
// CHECK: %{{.+}} = scf.for %[[ARG9:.+]] = %[[c0]] to %[[c4]] step %[[c1]]
// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]])
// CHECK: %[[A:.+]] = memref.load %[[ARG1]][%[[ARG7]], %[[X]], %[[ARG9]]] : memref<2x3x4xf32>
// CHECK: %[[B:.+]] = memref.load %[[ARG2]][%[[ARG7]], %[[ARG9]], %[[Y]]] : memref<2x4x3xf32>
Expand All @@ -252,12 +258,14 @@ func.func @gemm(%arg0: memref<8x9xf32>, %arg1: memref<9x10xf32>, %arg2: memref<8
// CHECK: gpu.launch_func @gemm_kernel::@gemm_kernel
// CHECK: gpu.module @gemm_kernel
// CHECK-LABEL: gpu.func @gemm_kernel
// CHECK-SAME: %[[ARG0:.+]]: memref<8x10xf32>, %[[ARG1:.+]]: memref<8x9xf32>,
// CHECK-SAME: %[[ARG2:.+]]: memref<9x10xf32>, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index)
// CHECK-SAME: %[[ARG0:.+]]: memref<8x10xf32>, %[[ARG1:.+]]: memref<8x9xf32>, %[[ARG2:.+]]: memref<9x10xf32>
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c9:.+]] = arith.constant 9 : index
// CHECK: %[[X:.+]] = gpu.block_id x
// CHECK-NEXT: %[[Y:.+]] = gpu.block_id y
// CHECK: %[[C:.+]] = memref.load %[[ARG0]][%[[X]], %[[Y]]] : memref<8x10xf32>
// CHECK: %[[R:.+]] = scf.for %[[ARG6:.+]] = %[[ARG3]] to %[[ARG4]] step %[[ARG5]]
// CHECK: %[[R:.+]] = scf.for %[[ARG6:.+]] = %[[c0]] to %[[c9]] step %[[c1]]
// CHECK-SAME: iter_args(%[[ARG7:.+]] = %[[C]])
// CHECK: %[[A:.+]] = memref.load %[[ARG1]][%[[X]], %[[ARG6]]] : memref<8x9xf32>
// CHECK: %[[B:.+]] = memref.load %[[ARG2]][%[[ARG6]], %[[Y]]] : memref<9x10xf32>
Expand Down
70 changes: 70 additions & 0 deletions test/GPU/gpu-inline-constants.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: tpp-opt %s -gpu-inline-constants -split-input-file | FileCheck %s

// RUN: tpp-opt %s -gpu-inline-constants -gpu-kernel-outlining -canonicalize -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=OUTLINED

// Scalar index constants defined in the host function are used inside the
// launch body as loop bounds and steps. The pass is expected to clone them
// into the body. %c1 must also remain outside, because it additionally
// provides the grid and block sizes of the launch itself.
func.func @scalar_constants(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
scf.for %i = %c0 to %c8 step %c1 {
scf.for %j = %c0 to %c16 step %c1 {
%0 = memref.load %arg0[%i, %j] : memref<8x16xf16>
memref.store %0, %arg1[%i, %j] : memref<8x16xf16>
}
}
gpu.terminator
}
return
}

// CHECK-LABEL: func.func @scalar_constants
// CHECK: arith.constant 1 : index
// CHECK: gpu.launch
// CHECK-DAG: arith.constant 0 : index
// CHECK-DAG: arith.constant 1 : index
// CHECK-DAG: arith.constant 8 : index
// CHECK-DAG: arith.constant 16 : index

// OUTLINED-LABEL: func.func @scalar_constants
// OUTLINED-SAME: %[[arg0:.+]]: memref<8x16xf16>, %[[arg1:.+]]: memref<8x16xf16>
// OUTLINED: arith.constant 1 : index
// OUTLINED: gpu.launch_func{{.*}}args(%[[arg0]] : memref<8x16xf16>, %[[arg1]] : memref<8x16xf16>)
// OUTLINED: gpu.module
// OUTLINED-LABEL: gpu.func @scalar_constants_kernel
// OUTLINED-DAG: arith.constant 0 : index
// OUTLINED-DAG: arith.constant 1 : index
// OUTLINED-DAG: arith.constant 8 : index
// OUTLINED-DAG: arith.constant 16 : index

// -----

// A dense vector constant defined in the host function is used inside the
// launch body. Like the scalar index constant %c0, it is expected to be
// cloned into the body rather than passed to the outlined kernel.
func.func @dense_constant(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%v0 = arith.constant dense<0.0> : vector<8x16xf16>
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
%0 = vector.load %arg0[%c0, %c0] : memref<8x16xf16>, vector<8x16xf16>
%1 = arith.maximumf %0, %v0 : vector<8x16xf16>
vector.store %1, %arg1[%c0, %c0] : memref<8x16xf16>, vector<8x16xf16>
gpu.terminator
}
return
}

// CHECK-LABEL: func.func @dense_constant
// CHECK: arith.constant 1 : index
// CHECK: gpu.launch
// CHECK-DAG: arith.constant 0 : index
// CHECK-DAG: arith.constant dense<0.000000e+00> : vector<8x16xf16>

// OUTLINED-LABEL: func.func @dense_constant
// OUTLINED-SAME: %[[arg0:.+]]: memref<8x16xf16>, %[[arg1:.+]]: memref<8x16xf16>
// OUTLINED: arith.constant 1 : index
// OUTLINED: gpu.launch_func{{.*}}args(%[[arg0]] : memref<8x16xf16>, %[[arg1]] : memref<8x16xf16>)
// OUTLINED: gpu.module
// OUTLINED-LABEL: gpu.func @dense_constant_kernel
// OUTLINED-DAG: arith.constant 0 : index
// OUTLINED-DAG: arith.constant dense<0.000000e+00> : vector<8x16xf16>

0 comments on commit 1c6c4ba

Please sign in to comment.