From a0945cc1aafbe2c9eabc807ddb6e8397ec6d0716 Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Fri, 23 Aug 2024 08:52:52 -0700
Subject: [PATCH] [Flow] Add pass to bubble and hoist encoding ops out of
 dispatch regions (#18063)

This PR adds a new pass in the Flow data tiling pipeline to hoist encoding
ops out of their dispatch regions. After SetEncoding, the encoding ops are
inserted directly inside the dispatch regions that contain the data-tiled
ops. The set_encoding ops then need to be hoisted out of the dispatch
region in order to fuse into their producer dispatches.

Sometimes there may be producer operations fused into the same dispatch as
the data-tiled op, in which case the set_encoding ops will have producers
inside their dispatch. In order to hoist such a set_encoding op, it needs
to be bubbled up through these producer operations until it has no
producers left inside its dispatch. This pass supports bubbling of
set_encoding ops through bit-extending ops and broadcasting ops.

After this pass, all set_encoding ops should be outside of dispatch
regions, and they then need to be fused with their producers. Another pass
will be added in the next PR to fuse set_encoding ops into their producer
dispatch regions or to wrap them in new dispatch regions.

---------

Signed-off-by: Max Dawkins
---
 .../Dialect/Encoding/IR/EncodingBase.td       |   3 +
 .../Dialect/Encoding/IR/EncodingOps.cpp       |   7 +
 .../Dialect/Flow/Transforms/BUILD.bazel       |   1 +
 .../Dialect/Flow/Transforms/CMakeLists.txt    |   1 +
 .../Flow/Transforms/HoistEncodingOps.cpp      | 217 ++++++++++++++++++
 .../Dialect/Flow/Transforms/Passes.cpp        |  24 +-
 .../Dialect/Flow/Transforms/Passes.td         |  10 +
 .../Dialect/Flow/Transforms/RegionOpUtils.cpp | 177 ++++++++++++++
 .../Dialect/Flow/Transforms/RegionOpUtils.h   |  12 +
 .../Dialect/Flow/Transforms/test/BUILD.bazel  |   1 +
 .../Flow/Transforms/test/CMakeLists.txt       |   1 +
 .../Transforms/test/hoist_encoding_ops.mlir   | 192 ++++++++++++++++
 .../Dialect/LinalgExt/Utils/Utils.cpp         |  40 ++++
 .../compiler/Dialect/LinalgExt/Utils/Utils.h  |  12 +
 14 files changed, 692 insertions(+), 6 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp
 create mode 100644 compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir

diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td
index dc89604a6cfb..befda327ce15 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td
@@ -113,6 +113,9 @@ def EncodingAttr :
     /// Returns an integer array with values in `round_dims_to`.
    ArrayRef<int64_t> getRoundDimsToArray();
+
+    /// Clones an encoding with a new bcast_map.
+    EncodingAttr clone(AffineMap bcastMap);
  }];

  let genVerifyDecl = 0;
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
index 01bcd7e1febb..0c3ef6dc6a9b 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
@@ -153,6 +153,13 @@ ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() {
   return llvm::cast<DenseI64ArrayAttr>(roundDimsTo).asArrayRef();
 }
 
+EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
+  return get(bcastMap.getContext(), getOperandIndex(), getOpType(),
+             getElementTypes(), getOriginalType(), getMatmulNarrow_M(),
+             getMatmulNarrow_N(), getUserIndexingMaps(),
+             AffineMapAttr::get(bcastMap), getRoundDimsTo());
+}
+
 //===---------------------------------------------------------------------===//
 // Encoding Dialect Helpers
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index fe5eaabd336a..b790a99defb6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -55,6 +55,7 @@ iree_compiler_cc_library(
         "FuseMultiUseElementwiseProducer.cpp",
         "FusionPreprocessing.cpp",
         "FusionUtils.cpp",
+        "HoistEncodingOps.cpp",
         "InitializeEmptyTensors.cpp",
         "InjectDispatchTracing.cpp",
         "InjectTensorTracing.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 701eefe5758d..7bbb5d5316d9 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -55,6 +55,7 @@ iree_cc_library(
       "FuseMultiUseElementwiseProducer.cpp"
      "FusionPreprocessing.cpp"
      "FusionUtils.cpp"
+      "HoistEncodingOps.cpp"
      "InitializeEmptyTensors.cpp"
      "InjectDispatchTracing.cpp"
      "InjectTensorTracing.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp
new file mode 100644
index 000000000000..0068744fa12b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp
@@ -0,0 +1,217 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
+#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-flow-hoist-encoding-ops"
+
+namespace mlir::iree_compiler::IREE::Flow {
+#define GEN_PASS_DEF_HOISTENCODINGOPSPASS
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+
+static AffineMap getBcastMapOrIdentity(RewriterBase &rewriter,
+                                       RankedTensorType encodedType) {
+  auto encoding =
+      cast<IREE::Encoding::EncodingAttr>(encodedType.getEncoding());
+  AffineMapAttr bcastMapAttr = encoding.getBcastMap();
+  return bcastMapAttr ? bcastMapAttr.getAffineMap()
+                      : rewriter.getMultiDimIdentityMap(encodedType.getRank());
+}
+
+/// Bubbles a SetEncodingOp up through a linalg::GenericOp. The `genericOp`
+/// must:
+///   1. Have a single result.
+///   2. Have a single use.
+///   3. Have all parallel iterators.
+///   4. Have an identity output indexing map.
+///   5. Have a tensor.empty init operand.
+///   6. Have as many indexing map dims as there are results in the encoding's
+///      bcast_map.
+///
+/// This function creates SetEncoding ops on all of the inputs to the
+/// `genericOp`, and replaces the op with an encoded version. If any of the
+/// above conditions are false, then it returns failure.
+///
+/// Note: The bcast_map on the set_encoding op must be identity or absent.
+/// The implementation should work for cases where it is not, but such cases
+/// are unexpected in IREE compilation and will not be well tested.
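+///
+/// For example, bubbling through a broadcast roughly rewrites (a rough
+/// sketch; the IR is illustrative and types/attributes are abbreviated):
+///   %b = linalg.generic ... ins(%src : tensor<NxMxf32>)   // broadcast
+///   %e = iree_encoding.set_encoding %b
+///       : tensor<BxNxMxf32> -> tensor<BxNxMxf32, #enc>
+/// into:
+///   %e = iree_encoding.set_encoding %src
+///       : tensor<NxMxf32> -> tensor<NxMxf32, #enc_with_composed_bcast_map>
+///   %b = linalg.generic ... ins(%e) -> tensor<BxNxMxf32, #enc>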
+static LogicalResult
+bubbleUpSetEncodingThroughGenericOp(RewriterBase &rewriter,
+                                    Encoding::SetEncodingOp encodingOp,
+                                    linalg::GenericOp genericOp) {
+  if (!genericOp->hasOneUse()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "genericOp must have one use");
+  }
+  if (genericOp.getNumDpsInits() != 1) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "genericOp must have a single init");
+  }
+  if (genericOp.getNumReductionLoops() != 0) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "genericOp must have all parallel loops");
+  }
+  if (!genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "init operand must be tensor.empty");
+  }
+  AffineMap outputMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  if (!outputMap.isIdentity()) {
+    return rewriter.notifyMatchFailure(genericOp, "output map not identity");
+  }
+
+  RankedTensorType encodedType = encodingOp.getResultType();
+  AffineMap bcastMap = getBcastMapOrIdentity(rewriter, encodedType);
+  if (!bcastMap.isIdentity()) {
+    return rewriter.notifyMatchFailure(genericOp, "bcast_map not identity");
+  }
+  if (outputMap.getNumDims() != bcastMap.getNumResults()) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "output map numDims does not match bcast_map numResults");
+  }
+
+  // Set encodings on each input.
+  Location loc = genericOp->getLoc();
+  SmallVector<Value> encodedOperands;
+  auto encoding =
+      cast<IREE::Encoding::EncodingAttr>(encodedType.getEncoding());
+  for (OpOperand *operand : genericOp.getDpsInputOperands()) {
+    // Compute the new bcastMap from the operand's indexing map.
+    AffineMap operandMap = genericOp.getMatchingIndexingMap(operand);
+    AffineMap newBcastMap = operandMap.compose(bcastMap);
+
+    // Create a new encoding and set it on the operand.
+    auto newEncoding = encoding.clone(newBcastMap);
+    auto operandType = cast<RankedTensorType>(operand->get().getType());
+    auto resType = RankedTensorType::get(
+        operandType.getShape(), operandType.getElementType(), newEncoding);
+    Value encodedInput = rewriter.create<Encoding::SetEncodingOp>(
+        loc, resType, operand->get());
+    encodedOperands.push_back(encodedInput);
+  }
+
+  // Create the encoded generic op.
+  SmallVector<OpFoldResult> mixedSizes =
+      tensor::getMixedSizes(rewriter, loc, encodingOp.getSource());
+  Value encodedInit = rewriter.create<tensor::EmptyOp>(
+      loc, mixedSizes, encodedType.getElementType(), encoding);
+  encodedOperands.push_back(encodedInit);
+  auto encodedGenericOp =
+      clone(rewriter, genericOp, encodingOp.getResultType(), encodedOperands);
+
+  rewriter.replaceOp(encodingOp, encodedGenericOp);
+  return success();
+}
+
+static LogicalResult bubbleUpSetEncoding(RewriterBase &rewriter,
+                                         OpOperand &operand) {
+  auto setEncoding = cast<Encoding::SetEncodingOp>(operand.getOwner());
+  auto producer = operand.get().getDefiningOp<linalg::GenericOp>();
+  if (!producer) {
+    return failure();
+  }
+  // Only bubble through dequantization ops and broadcasting ops for now.
+  if (!LinalgExt::isBitExtendOp(producer) &&
+      !LinalgExt::isBroadcastingOp(producer)) {
+    return failure();
+  }
+  return bubbleUpSetEncodingThroughGenericOp(rewriter, setEncoding, producer);
+}
+
+namespace {
+/// Pass declaration.
+struct HoistEncodingOpsPass
+    : public IREE::Flow::impl::HoistEncodingOpsPassBase<HoistEncodingOpsPass> {
+  using IREE::Flow::impl::HoistEncodingOpsPassBase<
+      HoistEncodingOpsPass>::HoistEncodingOpsPassBase;
+  void runOnOperation() override;
+};
+
+/// Pattern to bubble SetEncoding ops upwards through producers. This pattern
+/// runs until bubbling is not possible, or until the SetEncoding op is outside
+/// of a dispatch.
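+///
+/// For example (illustrative only), the pattern fires on the set_encoding
+/// here, because its producer lives in the same dispatch:
+///   %d = flow.dispatch.region {
+///     %p = linalg.generic ...              // bit-extend or broadcast
+///     %e = iree_encoding.set_encoding %p
+///     flow.return %e
+///   }
+/// and it stops once the set_encoding's source is defined outside the
+/// dispatch.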
+struct BubbleUpSetEncodingOp
+    : public OpRewritePattern<Encoding::SetEncodingOp> {
+  using OpRewritePattern<Encoding::SetEncodingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Encoding::SetEncodingOp encodingOp,
+                                PatternRewriter &rewriter) const override {
+    if (isNonNullAndOutsideDispatch(encodingOp)) {
+      return failure();
+    }
+    // Fail if the encodingOp is not in the same dispatch as its producer.
+    Operation *producer = encodingOp.getSource().getDefiningOp();
+    if (!producer) {
+      return failure();
+    }
+    auto dispatch = producer->getParentOfType<DispatchRegionOp>();
+    if (!dispatch ||
+        dispatch != encodingOp->getParentOfType<DispatchRegionOp>()) {
+      return failure();
+    }
+
+    return bubbleUpSetEncoding(rewriter, encodingOp->getOpOperand(0));
+  }
+};
+
+} // namespace
+
+/// Bubbles SetEncoding ops up through their producers, and then hoists them
+/// out of their dispatch regions.
+void HoistEncodingOpsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  auto funcOp = getOperation();
+
+  RewritePatternSet bubblingPatterns(ctx);
+  bubblingPatterns.insert<BubbleUpSetEncodingOp>(ctx);
+  if (failed(
+          applyPatternsAndFoldGreedily(funcOp, std::move(bubblingPatterns)))) {
+    return signalPassFailure();
+  }
+
+  SmallVector<Encoding::SetEncodingOp> candidates;
+  funcOp->walk([&](Encoding::SetEncodingOp setEncodingOp) {
+    if (setEncodingOp->getParentOfType<DispatchRegionOp>()) {
+      candidates.push_back(setEncodingOp);
+    }
+  });
+  IRRewriter rewriter(ctx);
+  for (auto setEncodingOp : candidates) {
+    if (failed(hoistOutOfDispatch(rewriter, setEncodingOp))) {
+      return signalPassFailure();
+    }
+  }
+
+  RewritePatternSet cleanPatterns(ctx);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(cleanPatterns);
+  DispatchRegionOp::getCanonicalizationPatterns(cleanPatterns, ctx);
+  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(cleanPatterns)))) {
+    return signalPassFailure();
+  }
+}
+} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 4c522a4e200b..5ce1211708e6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -279,12 +279,24 @@ static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
       // afterwards that would need the full dispatch content but don't want to
       // handle explicit captures as materialized as dispatch workgroup operands
      // and block arguments.
-      .addPass(IREE::Flow::createCloneProducersIntoDispatchRegionsPass)
-      .addPredicatedPass(clEnableDataTiling,
-                         [&]() {
-                           return createSetEncodingPass(
-                               SetEncodingPassOptions{clPadFactor});
-                         })
+      .addPass(IREE::Flow::createCloneProducersIntoDispatchRegionsPass);
+  // Experimental data tiling path. The intent of this path is to set encodings
+  // after fusion decisions have already been made, so encodings can be
+  // separated from compiler fusion decisions.
+  if (clEnableDataTiling) {
+    SetEncodingPassOptions options{clPadFactor};
+    FunctionLikeNest(passManager)
+        // Set encodings on all eligible ops. All ops should be in
+        // compiler-formed dispatch regions, so encodings will be placed
+        // inside the dispatch regions with the data-tiled op.
+        .addPass([&]() { return createSetEncodingPass(options); })
+        // SetEncodingOps should not be in the same dispatch as the data-tiled
+        // op, so hoist them out of their current dispatch regions. Also, bubble
+        // SetEncodingOps through special operations like bit-extending ops and
+        // broadcasting ops.
+        .addPass(IREE::Flow::createHoistEncodingOpsPass);
+  }
+
+  FunctionLikeNest(passManager)
      // Collapse dimensions of linalg Ops.
      .addPass(IREE::Flow::createCollapseDimensionsPass)
      // Convert dispatch regions into dispatch workgroups by capturing values
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 8fe891be6f6d..fabde78af4b1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -437,6 +437,16 @@ def ExportBenchmarkFuncsPass :
 }
 
+def HoistEncodingOpsPass :
+    InterfacePass<"iree-flow-hoist-encoding-ops", "mlir::FunctionOpInterface"> {
+  let summary = "Hoists tensor encoding ops out of flow dispatch regions.";
+  let dependentDialects = [
+    "mlir::linalg::LinalgDialect",
+    "IREE::Flow::FlowDialect",
+    "IREE::Encoding::IREEEncodingDialect",
+  ];
+}
+
 def InitializeEmptyTensorsPass :
    Pass<"iree-flow-initialize-empty-tensors", ""> {
  let summary = "Initialize empty tensors.";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 4b1d01757456..4b0e5fe87697 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -11,6 +11,7 @@
 #include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Analysis/SliceAnalysis.h"
@@ -528,6 +529,182 @@ wrapOpInDispatchRegion(RewriterBase &rewriter, Operation *op) {
   return newRegionOp;
 }
 
+FailureOr<Operation *> hoistOutOfDispatch(RewriterBase &rewriter,
+                                          Operation *op) {
+  assert(op && !isNonNullAndOutsideDispatch(op) &&
+         "op expected to be in a dispatch");
+
+  // Step 1: Clone the op outside of the dispatch region.
+
+  OpBuilder::InsertionGuard g(rewriter);
+  auto dispatchRegionOp = op->getParentOfType<Flow::DispatchRegionOp>();
+
+  // If all operands of the `op` come from outside the dispatch, then the op
+  // can be hoisted out before the dispatch region. Otherwise, the op can be
+  // hoisted out below the dispatch if the only users of the op are the
+  // dispatch return.
+  if (llvm::none_of(op->getOperands(), [&](Value operand) {
+        Operation *producer = operand.getDefiningOp();
+        return producer &&
+               producer->getParentOfType<Flow::DispatchRegionOp>();
+      })) {
+    rewriter.setInsertionPoint(dispatchRegionOp);
+  } else if (llvm::all_of(op->getUsers(), [&](Operation *user) {
+               return isa<Flow::ReturnOp>(user);
+             })) {
+    rewriter.setInsertionPointAfter(dispatchRegionOp);
+  } else {
+    return rewriter.notifyMatchFailure(
+        op, "op has both operands and users inside of its dispatch");
+  }
+  Operation *hoistedOp = rewriter.clone(*op);
+
+  // Step 2: Replace op uses inside and outside of the dispatch region with
+  //         the hoisted results.
+
+  auto getMatchingDispatchResult =
+      [&](Value result) -> std::optional<Value> {
+    for (OpOperand &use : result.getUses()) {
+      if (isa<Flow::ReturnOp>(use.getOwner())) {
+        return dispatchRegionOp.getResults()[use.getOperandNumber()];
+      }
+    }
+    return std::nullopt;
+  };
+  bool yieldsResults = false;
+  for (OpResult result : op->getResults()) {
+    Value hoistedResult = hoistedOp->getResult(result.getResultNumber());
+    // Replace all results yielded by the dispatch region with the hoisted
+    // op results.
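+    // Note: operand i of a flow.return corresponds to result i of its
+    // enclosing flow.dispatch.region, so the return's operand number can be
+    // used to index the dispatch region's results.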
+    std::optional<Value> dispResult = getMatchingDispatchResult(result);
+    if (dispResult.has_value()) {
+      yieldsResults = true;
+      rewriter.replaceAllUsesWith(dispResult.value(), hoistedResult);
+    }
+    // Replace uses inside the dispatch region.
+    rewriter.replaceAllUsesWith(result, hoistedResult);
+  }
+  // If no results were yielded from `op`, then there is nothing more to do.
+  if (!yieldsResults) {
+    return hoistedOp;
+  }
+
+  // Step 3: Collect the new set of dispatch results and dynamic dims, and
+  //         create a new dispatch region to replace the old one. The new
+  //         dispatch may have duplicated results.
+
+  // Get the new dispatch region return values and dynamic dims, excluding the
+  // ones coming from the `hoistedOp`.
+  auto dispatchReturnOp = cast<Flow::ReturnOp>(
+      dispatchRegionOp.getBody().front().getTerminator());
+  SmallVector<Value> newDispatchReturnOperands;
+  SmallVector<Value> newDispatchResultDynamicDims;
+  // Keep track of which results in the original dispatch region correspond to
+  // which results in the new dispatch region with `oldDispatchResultInds`.
+  SmallVector<int64_t> oldDispatchResultInds;
+  for (OpOperand &operand : dispatchReturnOp->getOpOperands()) {
+    if (operand.get().getDefiningOp() == hoistedOp) {
+      continue;
+    }
+    oldDispatchResultInds.push_back(operand.getOperandNumber());
+    newDispatchReturnOperands.push_back(operand.get());
+    auto dims =
+        dispatchRegionOp.getResultDynamicDims(operand.getOperandNumber());
+    newDispatchResultDynamicDims.append(dims.begin(), dims.end());
+  }
+
+  // Add the operands of the `op` to the new return values of the dispatch,
+  // and add their result dynamic dims to the new result dynamic dims.
+  // Save the result index in the new dispatch corresponding to each hoisted
+  // op operand in `resultIndsForHoistedOperands`, so uses can be replaced
+  // later.
+  SmallVector<int64_t> resultIndsForHoistedOperands;
+  for (OpOperand &operand : op->getOpOperands()) {
+    // Only need to yield operands defined in the dispatch region.
+    if (operand.get().getParentRegion() != &dispatchRegionOp.getBody()) {
+      continue;
+    }
+
+    // If the operand is already yielded by the dispatch, don't yield it
+    // again; just save the result index.
+    bool resultAlreadyYielded = false;
+    for (auto [idx, returnOperand] :
+         llvm::enumerate(newDispatchReturnOperands)) {
+      if (returnOperand == operand.get()) {
+        resultAlreadyYielded = true;
+        resultIndsForHoistedOperands.push_back(idx);
+        break;
+      }
+    }
+    if (resultAlreadyYielded) {
+      continue;
+    }
+    resultIndsForHoistedOperands.push_back(newDispatchReturnOperands.size());
+    newDispatchReturnOperands.push_back(operand.get());
+
+    // Save the operand's dynamic dims to add to the dispatch region.
+    SmallVector<Value> dims;
+    if (failed(reifyDynamicResultDims(rewriter, operand.get(), dims))) {
+      return op->emitOpError(
+          "failed to reify dynamic dims of result to be yielded from "
+          "dispatch region");
+    }
+    newDispatchResultDynamicDims.append(dims.begin(), dims.end());
+  }
+
+  // Create the new dispatch region op. `newDispatchReturnOperands` now has
+  // all the original return operands, excluding the hoisted op's results, and
+  // including any new results coming from the hoisted op's old operands. The
+  // `newDispatchResultDynamicDims` contains the corresponding result dynamic
+  // dims for `newDispatchReturnOperands`.
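+  //
+  // For example (illustrative), if the old return was
+  //   flow.return %hoisted, %other
+  // then `oldDispatchResultInds` is {1} (only %other survives), and the
+  // hoisted op's in-dispatch operands are appended as new results after it.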
+  SmallVector<Type> newResultTypes =
+      llvm::map_to_vector(newDispatchReturnOperands,
+                          [](Value operand) { return operand.getType(); });
+  rewriter.setInsertionPoint(dispatchRegionOp);
+  auto newDispatchRegionOp = rewriter.create<Flow::DispatchRegionOp>(
+      dispatchRegionOp->getLoc(), newResultTypes, newDispatchResultDynamicDims,
+      dispatchRegionOp.getWorkload());
+  rewriter.inlineRegionBefore(dispatchRegionOp.getBody(),
+                              newDispatchRegionOp.getBody(),
+                              newDispatchRegionOp.getBody().begin());
+  // Move the workgroup count region over.
+  if (!dispatchRegionOp.getWorkgroupCount().empty()) {
+    Region &newWorkgroupCountRegion = newDispatchRegionOp.getWorkgroupCount();
+    rewriter.inlineRegionBefore(dispatchRegionOp.getWorkgroupCount(),
+                                newWorkgroupCountRegion,
+                                newWorkgroupCountRegion.begin());
+  }
+  // Need to make a new flow.return op, since the body was copied from the
+  // old dispatch region.
+  auto newDispatchReturnOp = cast<Flow::ReturnOp>(
+      newDispatchRegionOp.getBody().front().getTerminator());
+  rewriter.setInsertionPoint(newDispatchReturnOp);
+  rewriter.replaceOpWithNewOp<Flow::ReturnOp>(newDispatchReturnOp,
+                                              newDispatchReturnOperands);
+
+  // Replace operands of the `hoistedOp` with dispatch region results. They
+  // are currently using values from inside the dispatch region.
+  for (auto [idx, operand] : llvm::enumerate(hoistedOp->getOperands())) {
+    auto newResultIdx = resultIndsForHoistedOperands[idx];
+    Value newDispatchResult = newDispatchRegionOp->getResults()[newResultIdx];
+    rewriter.replaceUsesWithIf(operand, newDispatchResult,
+                               [&](OpOperand &opOperand) {
+                                 return opOperand.getOwner() == hoistedOp;
+                               });
+  }
+
+  // Step 4: Replace the remaining uses of the original dispatch region's
+  //         results with the new dispatch region's results.
+  for (auto [newIdx, oldIdx] : llvm::enumerate(oldDispatchResultInds)) {
+    Value newDispatchResult = newDispatchRegionOp->getResults()[newIdx];
+    Value dispatchResult = dispatchRegionOp->getResults()[oldIdx];
+    rewriter.replaceAllUsesWith(dispatchResult, newDispatchResult);
+  }
+
+  return hoistedOp;
+}
+
 //===---------------------------------------------------------------------===//
 // Utilities to make a dispatch region isolated from above
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index 45a375f447d0..1451cc30ef29 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -104,6 +104,18 @@ FailureOr<Flow::DispatchRegionOp> wrapOpInDispatchRegion(RewriterBase &rewriter,
 /// into a dispatch region.
 bool isClonableIntoDispatchOp(Operation *op);
 
+/// Hoists an operation out of a dispatch region when either all of its
+/// producers come from outside the dispatch region, or all of its uses are
+/// part of the dispatch region op's return. If neither criterion is met,
+/// then returns failure.
+///
+/// If all producers are defined outside of the dispatch region, then the op
+/// will be hoisted above the dispatch region op. Otherwise, the op will be
+/// hoisted below the dispatch region op, and the operands of the hoisted op
+/// will be added to the yielded values of the dispatch region op.
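+///
+/// For example (illustrative only), hoisting an op below a dispatch rewrites:
+///   %0 = flow.dispatch.region -> (tensor<8xf32>) {
+///     %a = ...
+///     %b = "some.op"(%a)
+///     flow.return %b
+///   }
+/// into:
+///   %0 = flow.dispatch.region -> (tensor<8xf32>) {
+///     %a = ...
+///     flow.return %a
+///   }
+///   %1 = "some.op"(%0)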
+FailureOr<Operation *> hoistOutOfDispatch(RewriterBase &rewriter,
+                                          Operation *op);
+
 /// Collect all ops that should be cloned into the given dispatch region op.
 SmallVector<Operation *> getCloneableOps(Flow::DispatchRegionOp regionOp);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index ee4f9369d114..1c1692563103 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -42,6 +42,7 @@ iree_lit_test_suite(
             "fuse_horizontal_contractions.mlir",
            "fuse_multiuse_elementwise_producer.mlir",
            "fusion_preprocessing.mlir",
+            "hoist_encoding_ops.mlir",
            "initialize_empty_tensors.mlir",
            "inject_dispatch_tracing.mlir",
            "inject_tensor_tracing.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index a88bb4ff46f5..203a4ad9d7ad 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -40,6 +40,7 @@ iree_lit_test_suite(
     "fuse_horizontal_contractions.mlir"
    "fuse_multiuse_elementwise_producer.mlir"
    "fusion_preprocessing.mlir"
+    "hoist_encoding_ops.mlir"
    "initialize_empty_tensors.mlir"
    "inject_dispatch_tracing.mlir"
    "inject_tensor_tracing.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir
new file mode 100644
index 000000000000..ca97bb0f1ff6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir
@@ -0,0 +1,192 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-hoist-encoding-ops))" --split-input-file %s | FileCheck %s
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#lhs_encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+#rhs_encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+#result_encoding = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x64xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+module {
+  util.func public @hoist_matmul_encodings(%arg0: tensor<2x128x64xf32>, %arg1: tensor<2x11008x128xf32>) -> tensor<2x11008x64xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %2 = flow.dispatch.region -> (tensor<2x11008x64xf32>) {
+      %3 = iree_encoding.set_encoding %arg0 : tensor<2x128x64xf32> -> tensor<2x128x64xf32, #lhs_encoding>
+      %4 = iree_encoding.set_encoding %arg1 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #rhs_encoding>
+      %5 = tensor.empty() : tensor<2x11008x64xf32, #result_encoding>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x11008x64xf32, #result_encoding>) -> tensor<2x11008x64xf32, #result_encoding>
+      %7 = linalg.generic {
+          indexing_maps = [#map1, #map2, #map3],
+          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+          ins(%3, %4 : tensor<2x128x64xf32, #lhs_encoding>, tensor<2x11008x128xf32, #rhs_encoding>)
+          outs(%6 : tensor<2x11008x64xf32, #result_encoding>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %9 = arith.mulf %in, %in_0 : f32
+        %10 = arith.addf %9, %out : f32
+        linalg.yield %10 : f32
+      } -> tensor<2x11008x64xf32, #result_encoding>
+      %8 = iree_encoding.unset_encoding %7 : tensor<2x11008x64xf32, #result_encoding> -> tensor<2x11008x64xf32>
+      flow.return %8 : tensor<2x11008x64xf32>
+    }
+    util.return %2 : tensor<2x11008x64xf32>
+  }
+}

+// CHECK-LABEL: @hoist_matmul_encodings
+// CHECK-SAME:    (%[[ARG0:.+]]: tensor<2x128x64xf32>, %[[ARG1:.+]]: tensor<2x11008x128xf32>)
+// CHECK-DAG:   %[[SET_ENCODING0:.+]] = iree_encoding.set_encoding %[[ARG0]] : tensor<2x128x64xf32> -> tensor<2x128x64xf32, #iree_encoding.encoding
+// CHECK-DAG:   %[[SET_ENCODING1:.+]] = iree_encoding.set_encoding %[[ARG1]] : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<2x11008x64xf32>) {
+// CHECK:         %[[MATMUL:.+]] = linalg.generic {{.*}} ins(%[[SET_ENCODING0]], %[[SET_ENCODING1]]
+// CHECK:         %[[UNSET_ENCODING1:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<2x11008x64xf32, #iree_encoding.encoding
+// CHECK:         flow.return %[[UNSET_ENCODING1]] : tensor<2x11008x64xf32>
+// CHECK:       }
+// CHECK:       util.return %[[DISPATCH]] : tensor<2x11008x64xf32>

+// -----

+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], round_dims_to = array<i64: 32, 32, 32>>
+util.func public @bubble_through_dequant(
+    %arg0: tensor<2x11008x128xi8>, %arg1: tensor<2x11008xf32>, %arg2: tensor<2x11008xf32>) -> tensor<2x11008x128xf32, #encoding> {
+  %6 = flow.dispatch.region -> (tensor<2x11008x128xf32, #encoding>) {
+    %8 = tensor.empty() : tensor<2x11008x128xf32>
+    %11 = linalg.generic
+        {indexing_maps = [#map, #map1, #map1, #map],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+        ins(%arg0, %arg1, %arg2 : tensor<2x11008x128xi8>, tensor<2x11008xf32>, tensor<2x11008xf32>)
+        outs(%8 : tensor<2x11008x128xf32>) {
+    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
+      %18 = arith.extui %in : i8 to i32
+      %19 = arith.uitofp %18 : i32 to f32
+      %20 = arith.subf %19, %in_1 : f32
+      %21 = arith.mulf %20, %in_0 : f32
+      linalg.yield %21 : f32
+    } -> tensor<2x11008x128xf32>
+    %13 = iree_encoding.set_encoding %11 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+    flow.return %13 : tensor<2x11008x128xf32, #encoding>
+  }
+  util.return %6 : tensor<2x11008x128xf32, #encoding>
+}

+// CHECK-DAG:   #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG:   #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG:   #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: @bubble_through_dequant
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<2x11008x128xi8>,
+// CHECK-SAME:    %[[ARG1:.+]]: tensor<2x11008xf32>, %[[ARG2:.+]]: tensor<2x11008xf32>
+// CHECK-DAG:   %[[SET_ENCODING0:.+]] = iree_encoding.set_encoding %[[ARG0]] : {{.*}} bcast_map = #[[$MAP4]]
+// CHECK-DAG:   %[[SET_ENCODING1:.+]] = iree_encoding.set_encoding %[[ARG1]] : {{.*}} bcast_map = #[[$MAP3]]
+// CHECK-DAG:   %[[SET_ENCODING2:.+]] = iree_encoding.set_encoding %[[ARG2]] : {{.*}} bcast_map = #[[$MAP3]]
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK:         %[[DEQUANT:.+]] = linalg.generic {{.*}} ins(%[[SET_ENCODING0]], %[[SET_ENCODING1]], %[[SET_ENCODING2]] : {{.*}} outs(%[[INIT]] :
+// CHECK:         flow.return %[[DEQUANT]]
+// CHECK:       }
+// CHECK:       util.return %[[DISPATCH]]

+// -----

+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], round_dims_to = array<i64: 32, 32, 32>>
+util.func public @bubble_through_broadcast(
+    %arg0: tensor<11008x128xf32>) -> tensor<2x11008x128xf32, #encoding> {
+  %6 = flow.dispatch.region -> (tensor<2x11008x128xf32, #encoding>) {
+    %8 = tensor.empty() : tensor<2x11008x128xf32>
+    %11 = linalg.generic
+        {indexing_maps = [#map1, #map],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+        ins(%arg0 : tensor<11008x128xf32>)
+        outs(%8 : tensor<2x11008x128xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+    } -> tensor<2x11008x128xf32>
+    %13 = iree_encoding.set_encoding %11 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+    flow.return %13 : tensor<2x11008x128xf32, #encoding>
+  }
+  util.return %6 : tensor<2x11008x128xf32, #encoding>
+}

+// CHECK-DAG:   #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG:   #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG:   #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: @bubble_through_broadcast
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<11008x128xf32>
+// CHECK-DAG:   %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[ARG0]] : {{.*}} bcast_map = #[[$MAP3]]
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK:         %[[BROADCAST:.+]] = linalg.generic {{.*}} ins(%[[SET_ENCODING]] : {{.*}} outs(%[[INIT]] :
+// CHECK:         flow.return %[[BROADCAST]]
+// CHECK:       }
+// CHECK:       util.return %[[DISPATCH]]

+// -----

+#map = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+module {
+  util.func public @hoist_below(%arg0: tensor<2x11008x128xf32>) -> tensor<2x11008x128xf32, #encoding> {
+    %0 = flow.dispatch.region -> (tensor<2x11008x128xf32, #encoding>) {
+      %1 = tensor.empty() : tensor<2x11008x128xf32>
+      %2 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg0 : tensor<2x11008x128xf32>, tensor<2x11008x128xf32>) outs(%1 : tensor<2x11008x128xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %4 = arith.addf %in, %in_0 : f32
+        linalg.yield %4 : f32
+      } -> tensor<2x11008x128xf32>
+      %3 = iree_encoding.set_encoding %2 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+      flow.return %3 : tensor<2x11008x128xf32, #encoding>
+    }
+    util.return %0 : tensor<2x11008x128xf32, #encoding>
+  }
+}

+// CHECK-LABEL: @hoist_below
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<2x11008x128xf32>
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<2x11008x128xf32>
+// CHECK:         %[[ADD:.+]] = linalg.generic {{.*}} ins(%[[ARG0]], %[[ARG0]] : {{.*}} outs(%[[INIT]] :
+// CHECK:         flow.return %[[ADD]]
+// CHECK:       }
+// CHECK:       %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[DISPATCH]]
+// CHECK:       util.return %[[SET_ENCODING]]

+// -----

+#map = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<?x?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+module {
+  util.func public @hoist_dynamic(%arg0: tensor<?x?x?xf32>, %d0: index, %d1: index, %d2: index) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32, #encoding>) {
+    %0:2 = flow.dispatch.region -> (tensor<?x?x?xf32>{%d0, %d1, %d2}, tensor<?x?x?xf32, #encoding>{%d0, %d1, %d2}) {
+      %1 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+      %2 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg0 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%1 : tensor<?x?x?xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %4 = arith.addf %in, %in_0 : f32
+        linalg.yield %4 : f32
+      } -> tensor<?x?x?xf32>
+      %3 = iree_encoding.set_encoding %2 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #encoding>
+      flow.return %2, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32, #encoding>
+    }
+    util.return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32, #encoding>
+  }
+}

+// CHECK-LABEL: @hoist_dynamic
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[D0:.+]]: index, %[[D1:.+]]: index, %[[D2:.+]]: index)
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<?x?x?xf32>{%[[D0]], %[[D1]], %[[D2]]})
+// CHECK:         %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+// CHECK:         %[[ADD:.+]] = linalg.generic {{.*}} ins(%[[ARG0]], %[[ARG0]] : {{.*}} outs(%[[INIT]] :
+// CHECK:         flow.return %[[ADD]]
+// CHECK:       }
+// CHECK:       %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[DISPATCH]]
+// CHECK:       util.return %[[DISPATCH]], %[[SET_ENCODING]]
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index 30dccb01a62f..d46c58b7dbd8 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -244,4 +244,44 @@ bool isBitTruncateOp(Operation *op) {
   return isBitExtendOrTruncateOp(op) == BitWidthChangeInfo::kTruncate;
 }
 
+//===---------------------------------------------------------------------===//
+// Classification of other ops
+//===---------------------------------------------------------------------===//
+
+bool isBroadcastingOp(linalg::LinalgOp op) {
+  if (isa<linalg::BroadcastOp>(op)) {
+    return true;
+  }
+  auto genericOp = dyn_cast<linalg::GenericOp>(op.getOperation());
+  if (!genericOp) {
+    return false;
+  }
+
+  // Only allow a single input and init.
+  if (genericOp.getNumDpsInits() != 1 || genericOp.getNumDpsInputs() != 1) {
+    return false;
+  }
+
+  // Check that all loops are parallel.
+  unsigned numLoops = genericOp.getNumLoops();
+  unsigned numParallelLoops = genericOp.getNumParallelLoops();
+  if (numLoops != numParallelLoops) {
+    return false;
+  }
+
+  // Check that the indexing maps are broadcasting.
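+  // For example (illustrative), a broadcast along d0 has:
+  //   input map:  (d0, d1, d2) -> (d1, d2)      (projected permutation)
+  //   output map: (d0, d1, d2) -> (d0, d1, d2)  (identity)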
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  auto inMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0));
+  auto outMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  if (inMap.getNumResults() >= outMap.getNumResults()) {
+    return false;
+  }
+  if (!inMap.isProjectedPermutation() || !outMap.isIdentity()) {
+    return false;
+  }
+  return llvm::hasSingleElement(op.getBlock()->getOperations());
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index 3c4e139d80ac..d6794af9cad5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
 #define IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -127,5 +128,16 @@ bool isBitExtendOp(Operation *op);
 /// the output element type has a lower bitwidth.
 bool isBitTruncateOp(Operation *op);
 
+/// Returns true if the operation is a BroadcastOp or a GenericOp performing
+/// a broadcast. This function checks that the genericOp:
+///   1. Has a single input and output.
+///   2. Has all parallel loops.
+///   3. Has an identity output map.
+///   4. Has a projected permutation input map.
+///   5. Has an input map with fewer results than the output map.
+///   6. Has a body with only a linalg.yield op.
+bool isBroadcastingOp(linalg::LinalgOp op);
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_