forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GPU] Add a pass to convert accumulating GEMMs to GEMMs (iree-org#19587)
Converts dispatches with accumulating GEMMs that are doing in place read/write to GEMM + elementwise add. This is needed for the TileAndFuse path until we find a more permanent fix for iree-org#19546 --------- Signed-off-by: Nirvedh Meshram <[email protected]>
- Loading branch information
1 parent
550d88e
commit 80cbf6b
Showing
8 changed files
with
219 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
125 changes: 125 additions & 0 deletions
125
compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
// Copyright 2025 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
//===- ConvertAccGEMMToGEMMPass.cpp ---------------------------------------===// | ||
// | ||
// Converts Accumulating GEMM to GEMM + elementwise add. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tensor/Utils/Utils.h" | ||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" | ||
#include "mlir/IR/AffineMap.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Transforms/WalkPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_CONVERTACCGEMMTOGEMMPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
struct ConvertAccGEMMtoGEMM final | ||
: OpInterfaceRewritePattern<linalg::LinalgOp> { | ||
using OpInterfaceRewritePattern::OpInterfaceRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, | ||
PatternRewriter &rewriter) const override { | ||
if (!linalg::isaContractionOpInterface(linalgOp) && | ||
!isa<linalg::ConvolutionOpInterface>(*linalgOp)) { | ||
return failure(); | ||
} | ||
if (!linalgOp.hasPureTensorSemantics()) | ||
return failure(); | ||
|
||
// Nothing to do if the output tensor operand is already a fill op. | ||
SmallVector<OpOperand *> outputOperands; | ||
if (!linalgOp.hasPureBufferSemantics()) { | ||
outputOperands = llvm::to_vector( | ||
llvm::make_pointer_range(linalgOp.getDpsInitsMutable())); | ||
} | ||
|
||
Value outputOperand = outputOperands.front()->get(); | ||
|
||
auto outsDefiningOp = | ||
outputOperand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>(); | ||
if (!outsDefiningOp) { | ||
// If not DispatchTensorLoadOp then do nothing. | ||
return failure(); | ||
} | ||
auto outputType = cast<RankedTensorType>(outputOperand.getType()); | ||
if (!outputType.getElementType().isIntOrFloat()) | ||
return failure(); | ||
auto elementType = outputType.getElementType(); | ||
|
||
Location loc = linalgOp.getLoc(); | ||
|
||
// Check if the output tensor access is a projected permutation | ||
if (!linalgOp.getMatchingIndexingMap(outputOperands.front()) | ||
.isProjectedPermutation()) { | ||
return rewriter.notifyMatchFailure( | ||
linalgOp, "Output indexing map must be a projected permutation."); | ||
} | ||
|
||
int64_t outputRank = outputType.getRank(); | ||
SmallVector<utils::IteratorType> iterators(outputRank, | ||
utils::IteratorType::parallel); | ||
SmallVector<AffineMap> maps(3, rewriter.getMultiDimIdentityMap(outputRank)); | ||
|
||
// Create a zero tensor as the new output tensor operand to the Linalg | ||
// contraction op. | ||
SmallVector<OpFoldResult> mixedSizes = | ||
tensor::getMixedSizes(rewriter, loc, outputOperand); | ||
auto initOp = | ||
rewriter.create<tensor::EmptyOp>(loc, mixedSizes, elementType); | ||
Value zero = rewriter.create<arith::ConstantOp>( | ||
loc, rewriter.getZeroAttr(elementType)); | ||
Value fill = | ||
rewriter.create<linalg::FillOp>(loc, zero, initOp.getResult()).result(); | ||
|
||
// Update the contraction op to use the new zero tensor as output operand. | ||
rewriter.modifyOpInPlace(linalgOp, | ||
[&]() { linalgOp.setDpsInitOperand(0, fill); }); | ||
|
||
// Create a generic op to add back the original output tensor operand. | ||
rewriter.setInsertionPointAfter(linalgOp); | ||
auto genericOp = rewriter.create<linalg::GenericOp>( | ||
loc, outputType, ValueRange{linalgOp->getResult(0), outputOperand}, | ||
fill, maps, iterators, | ||
[&](OpBuilder &b, Location nestedLoc, ValueRange args) { | ||
Value result; | ||
if (llvm::isa<FloatType>(elementType)) { | ||
result = b.create<arith::AddFOp>(nestedLoc, args[0], args[1]); | ||
} else { | ||
result = b.create<arith::AddIOp>(nestedLoc, args[0], args[1]); | ||
} | ||
b.create<linalg::YieldOp>(nestedLoc, result); | ||
}); | ||
linalgOp->getResult(0).replaceAllUsesExcept(genericOp->getResult(0), | ||
genericOp); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct ConvertAccGEMMToGEMMPass final | ||
: impl::ConvertAccGEMMToGEMMPassBase<ConvertAccGEMMToGEMMPass> { | ||
void runOnOperation() override { | ||
RewritePatternSet patterns(&getContext()); | ||
patterns.add<ConvertAccGEMMtoGEMM>(&getContext()); | ||
walkAndApplyPatterns(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace | ||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
82 changes: 82 additions & 0 deletions
82
compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// RUN: iree-opt --split-input-file --iree-convert-accgemm-to-gemm %s | FileCheck %s | ||
|
||
#pipeline_layout = #hal.pipeline.layout<bindings = [ | ||
#hal.pipeline.binding<storage_buffer> | ||
]> | ||
|
||
// Accumulating GEMM case: the `outs` operand of the linalg.generic is loaded | ||
// from a readwrite dispatch tensor binding, and the result is stored back to | ||
// that same binding. The payload extends both i8 inputs to i32, multiplies | ||
// them and adds the product to the loaded accumulator. | ||
func.func @accumulate_gemm(%1 : tensor<512x128xi8>, %2 : tensor<512x128xi8>) { | ||
%c0 = arith.constant 0 : index | ||
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> | ||
%4 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> -> tensor<512x512xi32> | ||
%5 = linalg.generic { | ||
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, | ||
affine_map<(d0, d1, d2) -> (d1, d2)>, | ||
affine_map<(d0, d1, d2) -> (d0, d1)>], | ||
iterator_types = ["parallel", "parallel", "reduction"]} | ||
ins(%1, %2 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%4 : tensor<512x512xi32>) { | ||
^bb0(%in: i8, %in_0: i8, %out: i32): | ||
%6 = arith.extsi %in : i8 to i32 | ||
%7 = arith.extsi %in_0 : i8 to i32 | ||
%8 = arith.muli %6, %7 : i32 | ||
%9 = arith.addi %out, %8 : i32 | ||
linalg.yield %9 : i32 | ||
} -> tensor<512x512xi32> | ||
flow.dispatch.tensor.store %5, %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @accumulate_gemm | ||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 | ||
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<512x512xi32> | ||
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<512x512xi32>) -> tensor<512x512xi32> | ||
// CHECK: %[[GEMM:.+]] = linalg.generic {{.*}} outs(%[[FILL]] : tensor<512x512xi32>) { | ||
// CHECK: %[[ADD:.+]] = linalg.generic {{.+}} ins(%[[GEMM]] | ||
// CHECK: flow.dispatch.tensor.store %[[ADD]] | ||
|
||
|
||
// ----- | ||
|
||
#pipeline_layout = #hal.pipeline.layout<bindings = [ | ||
#hal.pipeline.binding<storage_buffer> | ||
]> | ||
|
||
// Accumulating convolution case: the `outs` operand of the named | ||
// linalg.conv_2d_nchw_fchw op is loaded from a readwrite dispatch tensor | ||
// binding (f32 elements), and the result is stored back to that same binding. | ||
func.func @acc_conv_nchw(%1 : tensor<1x64x58x58xf32>, %2 : tensor<64x64x3x3xf32>) { | ||
%c0 = arith.constant 0 : index | ||
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>> | ||
%4 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>> -> tensor<1x64x56x56xf32> | ||
%5 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} | ||
ins(%1, %2 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%4 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32> | ||
flow.dispatch.tensor.store %5, %3, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : tensor<1x64x56x56xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @acc_conv_nchw | ||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x64x56x56xf32> | ||
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[EMPTY]] : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32> | ||
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nchw_fchw {{.*}} outs(%[[FILL]] : tensor<1x64x56x56xf32>) | ||
// CHECK: %[[ADD:.+]] = linalg.generic {{.+}} ins(%[[CONV]] | ||
// CHECK: flow.dispatch.tensor.store %[[ADD]] | ||
|
||
// ----- | ||
|
||
#pipeline_layout = #hal.pipeline.layout<bindings = [ | ||
#hal.pipeline.binding<storage_buffer> | ||
]> | ||
|
||
|
||
// Negative case: the `outs` operand of the matmul is a zero linalg.fill of a | ||
// tensor.empty rather than a dispatch tensor load, and the binding is | ||
// writeonly. The CHECK lines that follow expect the matmul to remain and no | ||
// linalg.generic to appear after it. | ||
func.func @nonacc_gemm(%1 : tensor<512x128xi8>, %2 : tensor<512x128xi8>) { | ||
%c0_i32 = arith.constant 0 : i32 | ||
%c0 = arith.constant 0 : index | ||
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<512x512xi32>> | ||
%empty = tensor.empty() : tensor<512x512xi32> | ||
%fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<512x512xi32>) -> tensor<512x512xi32> | ||
%5 = linalg.matmul_transpose_b | ||
ins(%1, %2 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%fill : tensor<512x512xi32>) -> tensor<512x512xi32> | ||
flow.dispatch.tensor.store %5, %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<writeonly:tensor<512x512xi32>> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @nonacc_gemm | ||
// CHECK: linalg.matmul_transpose_b | ||
// CHECK-NOT: linalg.generic |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters