Skip to content

Commit

Permalink
[Flow] Add pass to bubble and hoist encoding ops out of dispatch regions (iree-org#18063)
Browse files Browse the repository at this point in the history

This PR adds a new pass in the Flow data tiling pipeline to hoist
encoding ops out of their dispatch regions. After SetEncoding, the
encoding ops are inserted directly inside of the dispatch regions that
contain the data tiled ops. The set_encoding ops then need to be hoisted
out of the dispatch region in order to fuse into the producer dispatch.

Sometimes there may be producer operations fused into the same dispatch
as the data tiled op, in which case the set_encoding ops will have
producers inside of the dispatch. In order to hoist the set_encoding op,
it needs to be bubbled up through these producer operations until it has
no producers inside of its dispatch. This pass supports bubbling of
set_encoding ops through bit extending ops and broadcasting ops.

After this pass, all set_encoding ops should be outside of dispatch
regions, and they need to be fused with their producers. Another pass
will be added in the next PR to fuse set_encoding ops into their
producer dispatch regions or wrap them in a new dispatch region.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Aug 23, 2024
1 parent c6924b6 commit a0945cc
Show file tree
Hide file tree
Showing 14 changed files with 692 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def EncodingAttr :

/// Returns an integer array with values in `round_dims_to`.
ArrayRef<int64_t> getRoundDimsToArray();

/// Clones an encoding with a new bcast_map
EncodingAttr clone(AffineMap bcastMap);
}];

let genVerifyDecl = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() {
return llvm::cast<DenseI64ArrayAttr>(roundDimsTo).asArrayRef();
}

/// Returns a copy of this encoding in which only the bcast_map is replaced by
/// `bcastMap`; every other encoding parameter is carried over unchanged.
EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
  MLIRContext *ctx = bcastMap.getContext();
  AffineMapAttr newBcastMapAttr = AffineMapAttr::get(bcastMap);
  return get(ctx, getOperandIndex(), getOpType(), getElementTypes(),
             getOriginalType(), getMatmulNarrow_M(), getMatmulNarrow_N(),
             getUserIndexingMaps(), newBcastMapAttr, getRoundDimsTo());
}

//===---------------------------------------------------------------------===//
// Encoding Dialect Helpers
//===---------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ iree_compiler_cc_library(
"FuseMultiUseElementwiseProducer.cpp",
"FusionPreprocessing.cpp",
"FusionUtils.cpp",
"HoistEncodingOps.cpp",
"InitializeEmptyTensors.cpp",
"InjectDispatchTracing.cpp",
"InjectTensorTracing.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ iree_cc_library(
"FuseMultiUseElementwiseProducer.cpp"
"FusionPreprocessing.cpp"
"FusionUtils.cpp"
"HoistEncodingOps.cpp"
"InitializeEmptyTensors.cpp"
"InjectDispatchTracing.cpp"
"InjectTensorTracing.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-flow-hoist-encoding-ops"

namespace mlir::iree_compiler::IREE::Flow {
#define GEN_PASS_DEF_HOISTENCODINGOPSPASS
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"

/// Returns the bcast_map carried by `encodedType`'s encoding attribute, or an
/// identity map with one dim per rank of `encodedType` when no bcast_map is
/// set.
static AffineMap getBcastMapOrIdentity(RewriterBase &rewriter,
                                       RankedTensorType encodedType) {
  auto encoding = cast<IREE::Encoding::EncodingAttr>(encodedType.getEncoding());
  if (AffineMapAttr bcastMapAttr = encoding.getBcastMap()) {
    return bcastMapAttr.getAffineMap();
  }
  return rewriter.getMultiDimIdentityMap(encodedType.getRank());
}

/// Bubbles a SetEncodingOp up through a linalg::GenericOp. The `genericOp`
/// must:
/// 1. Have a single result.
/// 2. Have single use.
/// 3. Have all parallel iterators.
/// 4. Have an identity output indexing map.
/// 5. Have a tensor.empty init operand.
/// 6. Have as many indexing map dims as there are results in the encoding's
///    bcast_map.
///
/// This function creates SetEncoding ops on all of the inputs to the
/// `genericOp`, and replaces the op with an encoded version. If any of
/// the above conditions are false, then it returns failure.
///
/// Note: The bcast_map on the set_encoding op must be identity or absent.
/// The implementation should work for cases where it is not, but it is
/// unexpected in IREE compilation to find such cases, and it will not
/// be well tested.
static LogicalResult
bubbleUpSetEncodingThroughGenericOp(RewriterBase &rewriter,
                                    Encoding::SetEncodingOp encodingOp,
                                    linalg::GenericOp genericOp) {
  // A multi-use generic cannot be replaced wholesale with an encoded clone
  // without affecting its other users.
  if (!genericOp->hasOneUse()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "genericOp must have one use");
  }
  if (genericOp.getNumDpsInits() != 1) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "genericOp must have a single init");
  }
  if (genericOp.getNumReductionLoops() != 0) {
    return rewriter.notifyMatchFailure(
        genericOp, "genericOp must have all parallel loops");
  }
  // The init must carry no data (pure tensor.empty) so it is safe to recreate
  // it below with the encoded type.
  if (!genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "init operand must be tensor.empty");
  }
  AffineMap outputMap =
      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
  if (!outputMap.isIdentity()) {
    return rewriter.notifyMatchFailure(genericOp, "output map not identity");
  }

  RankedTensorType encodedType = encodingOp.getResultType();
  AffineMap bcastMap = getBcastMapOrIdentity(rewriter, encodedType);
  if (!bcastMap.isIdentity()) {
    return rewriter.notifyMatchFailure(genericOp, "bcast_map map not identity");
  }
  if (outputMap.getNumDims() != bcastMap.getNumResults()) {
    return rewriter.notifyMatchFailure(
        genericOp, "output map numDims do not match bcast_map numResults");
  }

  // Set encodings on each input
  Location loc = genericOp->getLoc();
  SmallVector<Value> encodedOperands;
  auto encoding = cast<IREE::Encoding::EncodingAttr>(encodedType.getEncoding());
  for (OpOperand *operand : genericOp.getDpsInputOperands()) {
    // Compute the new bcastMap from the operand's indexing map. Since
    // `bcastMap` was verified to be identity above, the composition
    // effectively takes the operand's own indexing map as the new bcast_map.
    AffineMap operandMap = genericOp.getMatchingIndexingMap(operand);
    AffineMap newBcastMap = operandMap.compose(bcastMap);

    // Create new encoding and set encoding on the operand. The operand keeps
    // its original shape and element type; only the encoding is added.
    auto newEncoding = encoding.clone(newBcastMap);
    auto operandType = cast<RankedTensorType>(operand->get().getType());
    auto resType = RankedTensorType::get(
        operandType.getShape(), operandType.getElementType(), newEncoding);
    Value encodedInput =
        rewriter.create<Encoding::SetEncodingOp>(loc, resType, operand->get());
    encodedOperands.push_back(encodedInput);
  }

  // Create encoded generic op. The new init's sizes are taken from the
  // set_encoding source (the original generic's result), and it carries the
  // result encoding.
  SmallVector<OpFoldResult> mixedSizes =
      tensor::getMixedSizes(rewriter, loc, encodingOp.getSource());
  Value encodedInit = rewriter.create<tensor::EmptyOp>(
      loc, mixedSizes, encodedType.getElementType(), encoding);
  encodedOperands.push_back(encodedInit);
  auto encodedGenericOp =
      clone(rewriter, genericOp, encodingOp.getResultType(), encodedOperands);

  // The encoded generic now produces the encoded result directly, so the
  // set_encoding op is no longer needed.
  rewriter.replaceOp(encodingOp, encodedGenericOp);
  return success();
}

/// Attempts to move the SetEncodingOp owning `operand` above the producer of
/// `operand`'s value. Only succeeds when the producer is a linalg.generic
/// that is a bit-extending (dequantization) op or a broadcasting op.
static LogicalResult bubbleUpSetEncoding(RewriterBase &rewriter,
                                         OpOperand &operand) {
  auto setEncoding = cast<Encoding::SetEncodingOp>(operand.getOwner());
  auto genericProducer = operand.get().getDefiningOp<linalg::GenericOp>();
  if (!genericProducer) {
    return failure();
  }
  // Only bubble through dequantization ops and broadcasting ops for now.
  bool isSupportedProducer = LinalgExt::isBitExtendOp(genericProducer) ||
                             LinalgExt::isBroadcastingOp(genericProducer);
  if (!isSupportedProducer) {
    return failure();
  }
  return bubbleUpSetEncodingThroughGenericOp(rewriter, setEncoding,
                                             genericProducer);
}

namespace {
/// Pass declaration. The implementation lives in runOnOperation() below; the
/// generated base class provides pass registration and option parsing.
struct HoistEncodingOpsPass
    : public IREE::Flow::impl::HoistEncodingOpsPassBase<HoistEncodingOpsPass> {
  // Inherit constructors from the tablegen-generated base class.
  using IREE::Flow::impl::HoistEncodingOpsPassBase<
      HoistEncodingOpsPass>::HoistEncodingOpsPassBase;
  void runOnOperation() override;
};

/// Pattern to bubble SetEncoding ops upwards through producers. This pattern
/// runs until bubbling is not possible, or until the SetEncoding op is outside
/// of a dispatch.
struct BubbleUpSetEncodingOp
    : public OpRewritePattern<Encoding::SetEncodingOp> {
  using OpRewritePattern<Encoding::SetEncodingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(Encoding::SetEncodingOp encodingOp,
                                PatternRewriter &rewriter) const override {
    // Nothing to do once the set_encoding already sits outside any dispatch.
    if (isNonNullAndOutsideDispatch(encodingOp)) {
      return failure();
    }
    Operation *producer = encodingOp.getSource().getDefiningOp();
    if (!producer) {
      return failure();
    }
    // Fail if the encodingOp is not in the same dispatch as its producer.
    auto producerDispatch = producer->getParentOfType<DispatchRegionOp>();
    auto encodingDispatch = encodingOp->getParentOfType<DispatchRegionOp>();
    if (!producerDispatch || producerDispatch != encodingDispatch) {
      return failure();
    }

    return bubbleUpSetEncoding(rewriter, encodingOp->getOpOperand(0));
  }
};

} // namespace

/// Pass entry point: bubbles set_encoding ops up through supported producers
/// until they have no producers inside their dispatch region, then hoists
/// them out of the dispatch region entirely.
/// (NOTE: the previous comment here, "Create dispatch.region Ops based on a
/// fusion heuristic", described a different pass and was incorrect.)
void HoistEncodingOpsPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  auto funcOp = getOperation();

  // Step 1: Bubble set_encoding ops up through producers in the same
  // dispatch, so they end up with no in-dispatch producers.
  RewritePatternSet bubblingPatterns(ctx);
  bubblingPatterns.insert<BubbleUpSetEncodingOp>(ctx);
  if (failed(
          applyPatternsAndFoldGreedily(funcOp, std::move(bubblingPatterns)))) {
    return signalPassFailure();
  }

  // Step 2: Collect the set_encoding ops still inside a dispatch region, then
  // hoist each one out. Collect first and mutate after, so the walk does not
  // iterate over IR it is simultaneously modifying.
  SmallVector<Encoding::SetEncodingOp> candidates;
  funcOp->walk([&](Encoding::SetEncodingOp setEncodingOp) {
    if (setEncodingOp->getParentOfType<DispatchRegionOp>()) {
      candidates.push_back(setEncodingOp);
    }
  });
  IRRewriter rewriter(ctx);
  for (auto setEncodingOp : candidates) {
    if (failed(hoistOutOfDispatch(rewriter, setEncodingOp))) {
      return signalPassFailure();
    }
  }

  // Step 3: Resolve tensor.dim ops over the new results and canonicalize the
  // modified dispatch regions.
  RewritePatternSet cleanPatterns(ctx);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(cleanPatterns);
  DispatchRegionOp::getCanonicalizationPatterns(cleanPatterns, ctx);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(cleanPatterns)))) {
    return signalPassFailure();
  }
}
} // namespace mlir::iree_compiler::IREE::Flow
24 changes: 18 additions & 6 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,24 @@ static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
// afterwards that would need the full dispatch content but don't want to
// handle explicit captures as materialized as dispatch workgroup operands
// and block arguments.
.addPass(IREE::Flow::createCloneProducersIntoDispatchRegionsPass)
.addPredicatedPass(clEnableDataTiling,
[&]() {
return createSetEncodingPass(
SetEncodingPassOptions{clPadFactor});
})
.addPass(IREE::Flow::createCloneProducersIntoDispatchRegionsPass);
// Experimental data tiling path. The intent of this path is to set encodings
// after fusion decisions have already been made, so encodings can be
// separated from compiler fusion decisions.
if (clEnableDataTiling) {
SetEncodingPassOptions options{clPadFactor};
FunctionLikeNest(passManager)
// Set encodings on all eligible ops. All ops should be in compiler
// formed dispatch regions, so encodings will be placed inside of the
// dispatch regions with the data-tiled op.
.addPass([&]() { return createSetEncodingPass(options); })
// SetEncodingOps should not be in the same dispatch as the data-tiled
// op, so hoist them out of their current dispatch regions. Also, bubble
// SetEncodingOps through special operations like bit-extending ops and
// broadcasting ops.
.addPass(IREE::Flow::createHoistEncodingOpsPass);
}
FunctionLikeNest(passManager)
// Collapse dimensions of linalg Ops.
.addPass(IREE::Flow::createCollapseDimensionsPass)
// Convert dispatch regions into dispatch workgroups by capturing values
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,16 @@ def ExportBenchmarkFuncsPass :
}


// Hoists IREE::Encoding::SetEncodingOp ops out of flow.dispatch.region ops so
// that a later pass can fuse them with their producer dispatches. Runs on any
// function-like op (InterfacePass over FunctionOpInterface).
def HoistEncodingOpsPass :
    InterfacePass<"iree-flow-hoist-encoding-ops", "mlir::FunctionOpInterface"> {
  let summary = "Hoists tensor encoding ops out of flow dispatch regions.";
  // Dialects whose ops/attrs the pass may create and must be pre-loaded.
  let dependentDialects = [
    "mlir::linalg::LinalgDialect",
    "IREE::Flow::FlowDialect",
    "IREE::Encoding::IREEEncodingDialect",
  ];
}

def InitializeEmptyTensorsPass :
Pass<"iree-flow-initialize-empty-tensors", ""> {
let summary = "Initialize empty tensors.";
Expand Down
Loading

0 comments on commit a0945cc

Please sign in to comment.