forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LLVMCPU] Tile root and fuse consumer producer pass (iree-org#17804)
This is a draft patch to get helpful insights on consumer fusion. -- This patch tiles the root op and does producer-consumer fusion greedily. I want to iterate on this patch and get helpful input, because consumer fusion modifies and replaces the consumers in place, unlike producer fusion, where you can get the producers from the tiled ops. To upstream the `tileProducerAndFuseConsumerAPI`, we'll need both the original and the tiled ops.
- Loading branch information
Showing
8 changed files
with
321 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
212 changes: 212 additions & 0 deletions
212
compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" | ||
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h" | ||
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" | ||
#include "iree/compiler/Codegen/LLVMCPU/Passes.h" | ||
#include "iree/compiler/Codegen/LLVMCPU/Utils.h" | ||
#include "iree/compiler/Codegen/Utils/CPUUtils.h" | ||
#include "llvm/Support/CommandLine.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Linalg/Utils/Utils.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h" | ||
#include "mlir/Dialect/SCF/Transforms/Patterns.h" | ||
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" | ||
#include "mlir/Dialect/SCF/Transforms/Transforms.h" | ||
#include "mlir/IR/Iterators.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
// Keep DEBUG_TYPE identical to the pass argument
// (`iree-llvmcpu-tile-root-and-fuse-producer-consumer`) so that
// `--debug-only=<pass-argument>` enables the LLVM_DEBUG output of this file.
#define DEBUG_TYPE "iree-llvmcpu-tile-root-and-fuse-producer-consumer"
|
||
namespace mlir::iree_compiler { | ||
|
||
namespace { | ||
|
||
/// Tiles `root` with `options`, greedily fuses its producers into the tiled
/// loop nest, and then greedily fuses consumers that are reachable through the
/// `tensor.insert_slice` users of the tiled op.
///
/// Returns a match failure on `root` when it has no results or when tiling and
/// producer fusion fails. A consumer that cannot be fused is only logged and
/// skipped; it does not fail the transformation.
static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
    RewriterBase &rewriter, TilingInterface root,
    const scf::SCFTileAndFuseOptions &options) {
  // Only ops that return values are supported (i.e. this is not valid for
  // operations that have memref operands).
  if (root->getNumResults() == 0) {
    return rewriter.notifyMatchFailure(
        root, "invalid pattern for op with no results");
  }

  // Step 1: tile the root op and fuse its producers.
  FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
      scf::tileConsumerAndFuseProducersUsingSCF(rewriter, root, options);
  if (failed(tileAndFuseResult)) {
    return rewriter.notifyMatchFailure(
        root, "failed to tile root and fuse producers");
  }

  // Step 2: redirect the uses of the untiled root and of each fused producer
  // to the tiled version, and erase the untiled op once nothing uses it.
  auto replaceWithTiledVersion = [&](Operation *untiledOp) {
    for (OpResult untiledResult : untiledOp->getResults()) {
      Value tiledValue = tileAndFuseResult->replacements.lookup(untiledResult);
      if (!tiledValue)
        continue;
      rewriter.replaceAllUsesWith(untiledResult, tiledValue);
    }
    if (untiledOp->use_empty())
      rewriter.eraseOp(untiledOp);
  };
  replaceWithTiledVersion(root);
  for (Operation *fusedProducer : tileAndFuseResult->fusedProducers)
    replaceWithTiledVersion(fusedProducer);

  // Step 3: typically, the consumers of the tiled operation are slices of the
  // results of the tiled operation. These are expressed in IR using
  // `tensor.insert_slice` operations, whose outputs are the operands of the
  // untiled operation. Keep a worklist of these `tensor.insert_slice`
  // operations. If the consumers of the source of a `tensor.insert_slice` can
  // be tiled such that the tiled value is generated in-place, that effectively
  // tiles + fuses the operations.
  std::queue<tensor::InsertSliceOp> worklist;
  auto enqueueInsertSliceUsers = [&worklist](Operation *producerOp) {
    for (Operation *user : producerOp->getResults().getUsers()) {
      auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(user);
      if (insertSliceOp)
        worklist.push(insertSliceOp);
    }
  };

  // Seed the worklist with the slices written by the first tiled-and-fused op.
  enqueueInsertSliceUsers(tileAndFuseResult->tiledAndFusedOps.front());

  // Visit the slices in FIFO (breadth-first) order.
  while (!worklist.empty()) {
    tensor::InsertSliceOp sliceOp = worklist.front();
    worklist.pop();

    FailureOr<scf::SCFFuseConsumerOfSliceResult> consumerFusion =
        mlir::scf::tileAndFuseConsumerOfSlice(rewriter, sliceOp);
    if (failed(consumerFusion)) {
      LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
                              << sliceOp << "\n");
      continue;
    }

    // Replace the original consumer operation with the tiled implementation.
    Operation *untiledConsumer =
        consumerFusion->origConsumerOperand->getOwner();
    rewriter.replaceOp(untiledConsumer, consumerFusion->tiledOps.front());

    // The results of the fused consumers might themselves be slices of values
    // produced by operations that implement the `TilingInterface`. Add those
    // slices to the worklist.
    enqueueInsertSliceUsers(
        consumerFusion->tiledAndFusedConsumerOperand->getOwner());
  }
  return success();
}
|
||
/// Tiles `rootOp` with the tile sizes that its lowering config provides for
/// `tilingLevel`, then greedily fuses producers and consumers via
/// `tileRootAndFuseProducerConsumerUsingSCF`.
///
/// Returns failure when `rootOp` is null, when it has no lowering config, when
/// the config yields more tile sizes than `rootOp` has loops, or when the
/// tiling/fusion itself fails.
static LogicalResult tileRootAndFuseProducerConsumer(IRRewriter &rewriter,
                                                     TilingInterface rootOp,
                                                     int64_t tilingLevel) {
  // A failed cast at the call site yields a null interface; bail out instead
  // of dereferencing it below.
  if (!rootOp)
    return failure();

  auto loweringConfig = getLoweringConfig(rootOp);
  if (!loweringConfig)
    return failure();

  SmallVector<OpFoldResult> tileSizes =
      loweringConfig.getTilingLevelSizes(rewriter, tilingLevel, rootOp);

  // Compare unsigned against unsigned; the tile size list must not be longer
  // than the loop nest of the op.
  const size_t numLoops = rootOp.getLoopIteratorTypes().size();
  if (tileSizes.size() > numLoops)
    return failure();

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes);

  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.setTilingOptions(tilingOptions);

  return tileRootAndFuseProducerConsumerUsingSCF(rewriter, rootOp,
                                                 tileAndFuseOptions);
}
|
||
/// This pass starts with the first TilingInterface operation that has | ||
/// lowering_config attribute, tiles the op and fuses its consumers and | ||
/// producers recursively. The `tilingLevel` must be specified. It picks the | ||
/// `tilingLevel`-th list as tiling sizes from lowering_config. | ||
struct LLVMCPUTileRootAndFuseProducerConsumer | ||
: LLVMCPUTileRootAndFuseProducerConsumerBase< | ||
LLVMCPUTileRootAndFuseProducerConsumer> { | ||
LLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel = -1) { | ||
this->tilingLevel.setValue(tilingLevel); | ||
} | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<arith::ArithDialect, affine::AffineDialect, | ||
linalg::LinalgDialect, scf::SCFDialect, | ||
tensor::TensorDialect>(); | ||
} | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() { | ||
MLIRContext *context = &getContext(); | ||
auto funcOp = getOperation(); | ||
|
||
IRRewriter rewriter(funcOp); | ||
|
||
SmallVector<Operation *> computeOps = getComputeOps(funcOp); | ||
FailureOr<Operation *> rootOp = getRootOperation(computeOps); | ||
|
||
if (failed(rootOp)) { | ||
funcOp.emitError() << "not able to find the root operation\n"; | ||
return signalPassFailure(); | ||
} | ||
|
||
IREE::Codegen::LoweringConfigAttrInterface loweringConfig = | ||
getLoweringConfig(rootOp.value()); | ||
if (!loweringConfig) { | ||
funcOp.emitError() << "not able to find the lowering config\n"; | ||
return signalPassFailure(); | ||
} | ||
|
||
if (!loweringConfig.hasTilingLevel(tilingLevel)) { | ||
funcOp.emitError() | ||
<< "not able to find the lowering config with the tiling level " | ||
<< tilingLevel.getValue() << "\n"; | ||
return signalPassFailure(); | ||
} | ||
|
||
if (failed(tileRootAndFuseProducerConsumer( | ||
rewriter, dyn_cast<TilingInterface>(rootOp.value()), | ||
tilingLevel.getValue()))) { | ||
funcOp.emitError() << "tiling of level " << tilingLevel.getValue() | ||
<< " failed\n"; | ||
return signalPassFailure(); | ||
} | ||
|
||
RewritePatternSet patterns = | ||
linalg::getLinalgTilingCanonicalizationPatterns(context); | ||
scf::populateSCFForLoopCanonicalizationPatterns(patterns); | ||
tensor::populateFoldTensorEmptyPatterns(patterns); | ||
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); | ||
// Pull in tensor dialect canonicalization patterns to fold tensor.cast | ||
// into producers when possible. | ||
context->getLoadedDialect<tensor::TensorDialect>() | ||
->getCanonicalizationPatterns(patterns); | ||
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { | ||
LLVM_DEBUG(llvm::dbgs() << "----- cleanup failed -----\n"); | ||
return signalPassFailure(); | ||
} | ||
} | ||
} // namespace | ||
|
||
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> | ||
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) { | ||
return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(tilingLevel); | ||
} | ||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
90 changes: 90 additions & 0 deletions
90
compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0}), canonicalize)" --split-input-file %s | FileCheck %s | ||
|
||
#config1 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]> | ||
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> | ||
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2)> | ||
// Test input: a `linalg.mmt4d` (the op carrying `lowering_config`) whose | ||
// accumulator comes from a `linalg.fill` and whose result feeds an | ||
// elementwise `linalg.generic` that adds %arg2 and takes the maximum with | ||
// zero. The fill is the producer and the generic is the consumer to fuse. | ||
func.func @mmt4d_bias_relu(%arg0: tensor<?x?x16x1xf32>, %arg1: tensor<?x?x16x1xf32>, %arg2: tensor<?x16xf32>) -> tensor<?x?x16x16xf32> { | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%dim = tensor.dim %arg0, %c0 : tensor<?x?x16x1xf32> | ||
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x16x1xf32> | ||
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?x16x16xf32> | ||
%1 = tensor.empty(%dim, %dim_0) : tensor<?x?x16x16xf32> | ||
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32> | ||
%3 = linalg.mmt4d {lowering_config = #config1} ins(%arg0, %arg1 : tensor<?x?x16x1xf32>, tensor<?x?x16x1xf32>) outs(%2 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32> | ||
%4 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3, %arg2 : tensor<?x?x16x16xf32>, tensor<?x16xf32>) outs(%1 : tensor<?x?x16x16xf32>) { | ||
^bb0(%in: f32, %in_1: f32, %out: f32): | ||
%5 = arith.addf %in, %in_1 : f32 | ||
%6 = arith.maximumf %5, %cst : f32 | ||
linalg.yield %6 : f32 | ||
} -> tensor<?x?x16x16xf32> | ||
return %4 : tensor<?x?x16x16xf32> | ||
} | ||
// CHECK: func.func @mmt4d_bias_relu( | ||
// CHECK: scf.for | ||
// CHECK-SAME: { | ||
// CHECK: linalg.fill | ||
// CHECK: linalg.mmt4d | ||
// CHECK: linalg.generic | ||
// CHECK: } | ||
|
||
// ----- | ||
|
||
#config2 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]> | ||
// Test input: two `linalg.generic` ops convert i8 inputs to f32 (extui, | ||
// uitofp, subf, mulf) and produce both operands of a `linalg.batch_mmt4d` | ||
// (the op carrying `lowering_config`). Its accumulator comes from a | ||
// `linalg.fill` and its result is consumed by `tensor.unpack`. | ||
func.func @quantized_matmul() { | ||
%c2995200 = arith.constant 2995200 : index | ||
%c2994688 = arith.constant 2994688 : index | ||
%c2994176 = arith.constant 2994176 : index | ||
%c176128 = arith.constant 176128 : index | ||
%c88064 = arith.constant 88064 : index | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2995200) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4x128x16x1xi8>> | ||
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2994688) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> | ||
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2994176) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> | ||
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c176128) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x688x128x16x1xi8>> | ||
%4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c88064) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x688x16xf32>> | ||
%5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x688x16xf32>> | ||
%6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x11008x64xf32>> | ||
%7 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x128x16x1xi8>> -> tensor<2x4x128x16x1xi8> | ||
%8 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32> | ||
%9 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32> | ||
%10 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0, 0], sizes = [2, 688, 128, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x688x128x16x1xi8>> -> tensor<2x688x128x16x1xi8> | ||
%11 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [2, 688, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x688x16xf32>> -> tensor<2x688x16xf32> | ||
%12 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [2, 688, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x688x16xf32>> -> tensor<2x688x16xf32> | ||
%13 = tensor.empty() : tensor<2x4x128x16x1xf32> | ||
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7, %8, %9 : tensor<2x4x128x16x1xi8>, tensor<2x4x16xf32>, tensor<2x4x16xf32>) outs(%13 : tensor<2x4x128x16x1xf32>) { | ||
^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32): | ||
%21 = arith.extui %in : i8 to i32 | ||
%22 = arith.uitofp %21 : i32 to f32 | ||
%23 = arith.subf %22, %in_1 : f32 | ||
%24 = arith.mulf %23, %in_0 : f32 | ||
linalg.yield %24 : f32 | ||
} -> tensor<2x4x128x16x1xf32> | ||
%15 = tensor.empty() : tensor<2x688x128x16x1xf32> | ||
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%10, %11, %12 : tensor<2x688x128x16x1xi8>, tensor<2x688x16xf32>, tensor<2x688x16xf32>) outs(%15 : tensor<2x688x128x16x1xf32>) { | ||
^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32): | ||
%21 = arith.extui %in : i8 to i32 | ||
%22 = arith.uitofp %21 : i32 to f32 | ||
%23 = arith.subf %22, %in_1 : f32 | ||
%24 = arith.mulf %23, %in_0 : f32 | ||
linalg.yield %24 : f32 | ||
} -> tensor<2x688x128x16x1xf32> | ||
%17 = tensor.empty() : tensor<2x4x688x16x16xf32> | ||
%18 = linalg.fill ins(%cst : f32) outs(%17 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32> | ||
%19 = linalg.batch_mmt4d {lowering_config = #config2} ins(%14, %16 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%18 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32> | ||
%20 = tensor.empty() : tensor<2x11008x64xf32> | ||
%unpack = tensor.unpack %19 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 16] into %20 : tensor<2x4x688x16x16xf32> -> tensor<2x11008x64xf32> | ||
flow.dispatch.tensor.store %unpack, %6, offsets = [0, 0, 0], sizes = [2, 11008, 64], strides = [1, 1, 1] : tensor<2x11008x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x11008x64xf32>> | ||
return | ||
} | ||
// CHECK: func.func @quantized_matmul( | ||
// CHECK: scf.for | ||
// CHECK-SAME: { | ||
// CHECK: linalg.generic | ||
// CHECK: linalg.generic | ||
// CHECK: linalg.fill | ||
// CHECK: linalg.batch_mmt4d | ||
// CHECK: tensor.unpack | ||
// CHECK: } |