From cb158cf0f60507101cbdf440f82a9c1e8e221092 Mon Sep 17 00:00:00 2001 From: Eliasj42 <46754803+Eliasj42@users.noreply.github.com> Date: Sun, 5 Jun 2022 15:45:40 -0700 Subject: [PATCH 01/38] [CUDA] added fucntion to sync cuda context to current thread (#8) Co-authored-by: Elias Joseph --- runtime/src/iree/hal/drivers/cuda/cuda_device.c | 8 ++++++++ runtime/src/iree/hal/drivers/cuda/cuda_device.h | 2 ++ 2 files changed, 10 insertions(+) diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 565da3ae6d53..deab871ba52a 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -81,6 +81,14 @@ static iree_hal_cuda_device_t* iree_hal_cuda_device_cast_unsafe( return (iree_hal_cuda_device_t*)base_value; } +iree_status_t iree_cuda_set_current_thread(iree_hal_device_t* device){ + iree_hal_cuda_device_t* cuda_device = iree_hal_cuda_device_cast(device); + CUDA_RETURN_IF_ERROR(cuda_device->context_wrapper.syms, + cuCtxSetCurrent(cuda_device->context_wrapper.cu_context), + "cuCtxSetCurrent"); + return iree_ok_status(); +} + IREE_API_EXPORT void iree_hal_cuda_device_params_initialize( iree_hal_cuda_device_params_t* out_params) { memset(out_params, 0, sizeof(*out_params)); diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.h b/runtime/src/iree/hal/drivers/cuda/cuda_device.h index 0cc08870c6a0..2af0c77c68bd 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.h +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.h @@ -41,6 +41,8 @@ CUcontext iree_hal_cuda_device_context(iree_hal_device_t* device); iree_hal_cuda_dynamic_symbols_t* iree_hal_cuda_device_dynamic_symbols( iree_hal_device_t* device); +iree_status_t iree_cuda_set_current_thread(iree_hal_device_t* device); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus From dffde12af549edf44bb1f4ba48a0c8b981313f32 Mon Sep 17 00:00:00 2001 From: stanley Date: Mon, 31 Oct 2022 07:10:06 +0000 Subject: [PATCH 02/38] [vulkan] Modify subspan to handle static cast. --- .../Common/FlattenMemRefSubspanPass.cpp | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp index 88d3474402d4..c1685ce0d43b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp @@ -736,6 +736,25 @@ struct RemoveDynamicCastOp final : public OpRewritePattern { } }; +/// Removes memref.cast that turns dynamic shapes into static shapes. +struct RemoveStaticCastOp final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CastOp castOp, + PatternRewriter &rewriter) const override { + auto srcType = castOp.getSource().getType().cast(); + auto dstType = castOp.getType().cast(); + // Restrict to the cases we generate in this pass--1-D static shape to 1-D + // dynamic shape. + if (srcType.getRank() == 1 && !srcType.hasStaticShape() && + dstType.getRank() == 1 && dstType.hasStaticShape()) { + rewriter.replaceOp(castOp, castOp.getSource()); + return success(); + } + return failure(); + } +}; + //===----------------------------------------------------------------------===// // Pass //===----------------------------------------------------------------------===// @@ -894,6 +913,7 @@ struct FlattenMemRefSubspanPass memref::AllocaOp::getCanonicalizationPatterns(cleanupPatterns, context); memref::SubViewOp::getCanonicalizationPatterns(cleanupPatterns, context); cleanupPatterns.add(context); + cleanupPatterns.add(context); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(cleanupPatterns)))) { From 3bc784d9b24d8e2b5845e46ec373215f68db3ecd Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 29 Sep 2022 15:44:59 -0400 Subject: [PATCH 03/38] [Flow] Pass to convert NCHW convolutions to NHWC The conversion pass is enabled with `--iree-flow-enable-conv-nchw-to-nhwc-transform` Includes partial support for propagating and cancelling transposes generated when converting from nchw to nhwc. The high level strategy for this pass is as follows: 1. Do the conversions for all conv_nchw_fchw ops (and pooling ops) and wrap the converted convolutions in transposes. Each transpose is tagged to indicate which direction the transpose should propagate through the graph. 2. Traverse the ops in the function in reverse to propagate transposes marked for upwards propagation to their parents. Ideally just before ops such as arith.constant or function arguments. 3. Propagate the transposes marked for downward propagation to its users, ideally to just before return. 4. Canonicalize out all adjacent cancelling transposes and generalize the remaining transposes to allow for fusing them with nearby ops. --- .../compiler/Preprocessing/Common/BUILD.bazel | 1 + .../Preprocessing/Common/CMakeLists.txt | 1 + .../Common/ConvertConvNchwToNhwc.cpp | 564 ++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.h | 7 +- .../compiler/Preprocessing/Common/Passes.td | 7 + .../Preprocessing/Common/test/BUILD.bazel | 1 + .../Preprocessing/Common/test/CMakeLists.txt | 1 + .../Common/test/conv2d_nchw_to_nhwc.mlir | 40 ++ 8 files changed, 621 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index c9acfa23b3e5..eb1f9fabdc0d 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -31,6 +31,7 @@ iree_compiler_cc_library( name = "Transforms", srcs = [ "ConvertConv2DToImg2Col.cpp", + "ConvertConvNchwToNhwc.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", "PassDetail.h", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index fb7b26ff5dcf..b6ca3b0d3d5b 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -27,6 +27,7 @@ iree_cc_library( "Passes.h.inc" SRCS "ConvertConv2DToImg2Col.cpp" + "ConvertConvNchwToNhwc.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" "PassDetail.h" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp new file mode 100644 index 000000000000..5704f18ed466 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp @@ -0,0 +1,564 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-flow-convert-conv-nchw-to-nhwc" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +using TransposeIndices = SmallVector; + +static const StringLiteral transposeEmptyMarker = "__nchw_to_nhwc_init__"; +static const StringLiteral transposePropagateUpMarker = "__nchw_to_nhwc_up__"; +static const StringLiteral transposePropagateDownMarker = + "__nchw_to_nhwc_down__"; + +static TransposeIndices invertIndices(TransposeIndices targetIndices) { + auto rank = targetIndices.size(); + TransposeIndices inverted(rank); + for (auto i : llvm::enumerate(targetIndices)) { + inverted[i.value()] = i.index(); + } + return inverted; +} + +static TransposeIndices getTransposeIndices(linalg::TransposeOp op) { + return llvm::to_vector(op.getPermutation()); +} + +static bool isStaticallyShaped(Value input) { + if (auto inputType = input.getType().dyn_cast()) + return inputType.hasStaticShape(); + return false; +} + +// Get the transpose indices if the given input comes from a transpose and is +// marked to propagate down. +static std::optional getIndicesFromInput(Value input) { + if (!isStaticallyShaped(input)) return std::nullopt; + auto parent = input.getDefiningOp(); + if (parent && parent->hasAttr(transposePropagateDownMarker)) + return getTransposeIndices(parent); + return std::nullopt; +} + +// Get the transpose indices if the given output is used by at least one +// transpose and that transpose is marked to propagate up. Additionally don't +// propagate if there are conflicting transposes. +static std::optional getIndicesFromOutput(Value output) { + if (!isStaticallyShaped(output)) return std::nullopt; + std::optional transposedOut; + if (llvm::all_of(output.getUses(), [&transposedOut](const OpOperand &use) { + auto owner = dyn_cast(use.getOwner()); + if (owner && owner->hasAttr(transposePropagateUpMarker)) { + if (transposedOut.has_value()) { + if (getTransposeIndices(transposedOut.value()) == + getTransposeIndices(owner)) + return true; + return false; + } + transposedOut = owner; + return true; + } + return false; + })) { + if (transposedOut.has_value()) + return getTransposeIndices(transposedOut.value()); + } + return std::nullopt; +} + +// Helper to shuffle vectors according to the transpose indices. +template +static SmallVector shuffleFromIndices(SmallVector unshuffled, + TransposeIndices targetIndices) { + auto rank = unshuffled.size(); + assert(targetIndices.size() == rank && + "Mismatch between number of elements in input and number of indices"); + SmallVector shuffled(rank); + + for (auto i : llvm::enumerate(targetIndices)) { + shuffled[i.index()] = unshuffled[i.value()]; + } + return shuffled; +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. +static Value createTranspose(PatternRewriter &rewriter, Location loc, + Value input, TransposeIndices targetIndices, + bool propagateUp) { + RankedTensorType inType = input.getType().cast(); + auto elementType = inType.getElementType(); + auto inputShape(inType.getShape()); + + auto outputShape = + shuffleFromIndices(llvm::to_vector(inputShape), targetIndices); + + Value output = + rewriter.create(loc, outputShape, elementType); + output.getDefiningOp()->setAttr(transposeEmptyMarker, rewriter.getUnitAttr()); + + auto transpose = + rewriter.create(loc, input, output, targetIndices); + transpose->setAttr( + propagateUp ? transposePropagateUpMarker : transposePropagateDownMarker, + rewriter.getUnitAttr()); + return transpose.getResults()[0]; +} + +// Supports conv and pooling ops, where pooling ops don't transpose the filter. +template +static LogicalResult convertConvLikeNchwToNhwc(PatternRewriter &rewriter, + ConvOpTy convOp, + bool transposeFilter) { + LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n"); + + Location loc = convOp.getLoc(); + + Value input = convOp.image(); + Value filter = convOp.filter(); + Value output = convOp.getOutputs()[0]; + + if (!isStaticallyShaped(input) || !isStaticallyShaped(output) || + (transposeFilter && !isStaticallyShaped(filter))) { + return failure(); + } + + TransposeIndices NCHWIndices = {0, 2, 3, 1}; + + auto transposedInput = + createTranspose(rewriter, loc, input, NCHWIndices, true); + auto transposedFilter = filter; + if (transposeFilter) { + TransposeIndices FCHWIndices = {2, 3, 1, 0}; + transposedFilter = + createTranspose(rewriter, loc, filter, FCHWIndices, true); + } + auto transposedOutput = + createTranspose(rewriter, loc, output, NCHWIndices, true); + + auto conv = + rewriter + .create(loc, transposedOutput.getType(), + ValueRange{transposedInput, transposedFilter}, + transposedOutput, convOp.getStrides(), + convOp.getDilations()) + .getResult(0); + + auto returnToNCHW = + createTranspose(rewriter, loc, conv, invertIndices(NCHWIndices), false); + + rewriter.replaceOp(convOp, returnToNCHW); + return success(); +} + +namespace { + +/* + * Convolution conversion patterns + */ + +struct ConvertLinalgConvNchwFchw : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, convOp, + true); + } +}; + +struct ConvertLinalgPoolingNchwMax + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::PoolingNchwMaxOp poolOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, poolOp, + false); + } +}; + +struct ConvertLinalgPoolingNchwSum + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::PoolingNchwSumOp poolOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, poolOp, + false); + } +}; + +/* + * Transpose propagation patterns + */ + +struct PropagateThroughTensorPadPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughTensorPadPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + if (propagateUp) { + auto indices = getIndicesFromOutput(padOp.getResult()); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + auto indices = getIndicesFromInput(padOp.getSource()); + if (!indices.has_value()) return failure(); + transposeIndices = invertIndices(indices.value()); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << padOp << "\n"); + + Location loc = padOp.getLoc(); + + auto input = padOp.getSource(); + SmallVector mixedLow = shuffleFromIndices( + padOp.getMixedLowPad(), transposeIndices); + SmallVector mixedHigh = shuffleFromIndices( + padOp.getMixedHighPad(), transposeIndices); + + auto transposedInput = + createTranspose(rewriter, loc, input, transposeIndices, true); + + SmallVector outputShape(padOp.getResultType().getShape()); + SmallVector transposedOutputShape = + shuffleFromIndices(outputShape, transposeIndices); + RankedTensorType transposedOutputType = RankedTensorType::get( + transposedOutputShape, padOp.getResultType().getElementType()); + + auto newPad = rewriter.create(loc, transposedOutputType, + transposedInput, mixedLow, + mixedHigh, padOp.getNofold()); + IRMapping mapper; + padOp.getRegion().cloneInto(&newPad.getRegion(), mapper); + + auto returnToNCHW = createTranspose(rewriter, loc, newPad.getResult(), + invertIndices(transposeIndices), false); + + rewriter.replaceOp(padOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughLinalgFillPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughLinalgFillPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(linalg::FillOp fillOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + if (propagateUp) { + auto indices = getIndicesFromOutput(fillOp.getResult(0)); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + auto indices = getIndicesFromInput(fillOp.value()); + if (!indices.has_value()) return failure(); + transposeIndices = invertIndices(indices.value()); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << fillOp << "\n"); + Location loc = fillOp.getLoc(); + + auto transposedOutput = + createTranspose(rewriter, loc, fillOp.output(), transposeIndices, true); + + auto newTensor = + rewriter.create(loc, fillOp.value(), transposedOutput) + .getResult(0); + + auto returnToNCHW = createTranspose(rewriter, loc, newTensor, + invertIndices(transposeIndices), false); + + rewriter.replaceOp(fillOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughLinalgGenericPattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughLinalgGenericPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), + propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + // For now restrict to single results. + if (genericOp.getNumResults() != 1) return failure(); + + if (propagateUp) { + auto indices = getIndicesFromOutput(genericOp.getOutputs()[0]); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + // TODO: Enable directly fusing the transpose with the inputs. + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << genericOp << "\n"); + + Location loc = genericOp.getLoc(); + + auto transposedOutput = genericOp.getOutputs()[0]; + auto indexingMaps = genericOp.getIndexingMapsArray(); + + if (propagateUp) { + transposedOutput = createTranspose(rewriter, loc, transposedOutput, + transposeIndices, true); + + AffineMap outMap = indexingMaps.back(); + SmallVector outExprs(outMap.getResults()); + SmallVector exprs = + shuffleFromIndices(outExprs, transposeIndices); + indexingMaps[indexingMaps.size() - 1] = + AffineMap::get(outMap.getNumDims(), outMap.getNumSymbols(), exprs, + genericOp->getContext()); + } + + SmallVector newInputs; + for (auto input : llvm::enumerate(genericOp.getInputs())) { + newInputs.push_back(input.value()); + } + + SmallVector iteratorTypes = + genericOp.getIteratorTypesArray(); + + auto newGeneric = rewriter.create( + loc, transposedOutput.getType().cast(), newInputs, + transposedOutput, indexingMaps, iteratorTypes); + IRMapping mapper; + genericOp.getRegion().cloneInto(&newGeneric.getRegion(), mapper); + + Value returnToNCHW = newGeneric.getResult(0); + if (propagateUp) { + returnToNCHW = createTranspose(rewriter, loc, returnToNCHW, + invertIndices(transposeIndices), false); + } + + rewriter.replaceOp(genericOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughTensorEmptyPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::EmptyOp emptyOp, + PatternRewriter &rewriter) const override { + if (emptyOp->hasAttr(transposeEmptyMarker)) return failure(); + TransposeIndices transposeIndices; + + auto indices = getIndicesFromOutput(emptyOp.getResult()); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + + LLVM_DEBUG(llvm::dbgs() << "propagating " << emptyOp << "\n"); + + Location loc = emptyOp.getLoc(); + + SmallVector mixedSizes = shuffleFromIndices( + emptyOp.getMixedSizes(), transposeIndices); + + auto newTensor = rewriter.create( + loc, mixedSizes, emptyOp.getType().getElementType()); + auto returnToNCHW = createTranspose(rewriter, loc, newTensor.getResult(), + invertIndices(transposeIndices), false); + + rewriter.replaceOp(emptyOp, returnToNCHW); + return success(); + } +}; + +/* + * Folding away cancelling transposes and generalizing + */ + +// Cancel if this transpose is tagged with a propagating tag and the defining op +// for the input is the inverse of this transpose +struct CancelNCHWToNHWCTransposePattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto transposeIndices = invertIndices(getTransposeIndices(transposeOp)); + + auto parentOp = + transposeOp->getOperand(0).getDefiningOp(); + if (parentOp) { + if (getTransposeIndices(parentOp) == transposeIndices) { + rewriter.replaceOp(transposeOp, parentOp->getOperand(0)); + return success(); + } + } + + return failure(); + } +}; + +struct GeneralizeTransposeOpPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + if (transposeOp->hasAttr(transposePropagateUpMarker) || + transposeOp->hasAttr(transposePropagateDownMarker)) { + auto context = rewriter.getContext(); + auto rank = + transposeOp.getResultTypes()[0].cast().getRank(); + + auto transposeIndices = getTransposeIndices(transposeOp); + + SmallVector idExprs; + for (auto i = 0; i < rank; i++) + idExprs.push_back(getAffineDimExpr(i, context)); + + SmallVector swapExprs = + shuffleFromIndices(idExprs, transposeIndices); + + SmallVector indexingMaps = { + AffineMap::get(rank, 0, idExprs, context), + AffineMap::get(rank, 0, swapExprs, context)}; + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + + rewriter.replaceOpWithNewOp( + transposeOp, transposeOp.getResultTypes()[0], + transposeOp.getOperand(0), transposeOp.getOperand(1), indexingMaps, + iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }); + return success(); + } + return failure(); + } +}; + +// The high level strategy for this pass is as follows: +// 1. Do the conversions for all conv_nchw_fchw ops (and pooling ops) and +// wrap the converted convolutions in transposes. Each transpose is tagged +// to indicate which direction the transpose should propagate through the +// graph. +// 2. Traverse the ops in the function in reverse to propagate transposes +// marked for upwards propagation to their parents. Ideally just before ops +// such as arith.constant or function arguments. +// 3. Propagate the transposes marked for downward propagation to its users, +// ideally to just before return. +// 4. Canonicalize out all adjacent cancelling transposes and generalize the +// remaining transposes to allow for fusing them with nearby ops. +struct ConvertConvNchwToNhwcPass + : public ConvertConvNchwToNhwcBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + Operation *funcOp = getOperation(); + MLIRContext *context = &getContext(); + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + } + + // Propagate transposes up the graph. + { + SmallVector ops; + funcOp->walk([&](Operation *op) { ops.push_back(op); }); + + RewritePatternSet patterns(context); + patterns.insert(context, true); + patterns.insert(context); + patterns.insert(context, true); + patterns.insert(context, true); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + SmallVector reverseOps(llvm::reverse(ops)); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::AnyOp; + (void)applyOpPatternsAndFold(reverseOps, frozenPatterns, config); + } + + // Propagate transposes down the graph. + { + RewritePatternSet patterns(context); + patterns.insert(context, false); + patterns.insert(context, false); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + // Cancel out transposes. + { + RewritePatternSet patterns(context); + patterns.insert(context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + // Generalize remaining transposes to allow fusion with other ops. + { + RewritePatternSet patterns(context); + patterns.insert(context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + } +}; + +} // namespace + +std::unique_ptr> +createConvertConvNchwToNhwcPass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index fcf5ddf86499..6c897172d7d5 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -10,6 +10,7 @@ #include #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -19,11 +20,15 @@ namespace mlir::iree_compiler::Preprocessing { /// using im2col tranformation. std::unique_ptr createConvertConv2DToImg2ColPass(); +// Creates a pass to convert linalg NCHW Convolutions to NHWC. +std::unique_ptr> +createConvertConvNchwToNhwcPass(); + /// Moves the body of the entire function into a single dispatch. std::unique_ptr> createMakeSingleDispatchForFunctionPass(); -/// A pass to pad linalg ops to the next integer multiple of `paddingSize`. +// A pass to pad linalg ops to the next integer multiple of `paddingSize`. std::unique_ptr createPadLinalgOpsToIntegerMultiplePass(); //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index ae67e754cfcc..ed8b9194aecf 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -15,6 +15,13 @@ def ConvertConv2DToImg2Col : let constructor = "mlir::iree_compiler::Preprocessing::createConvertConv2DToImg2ColPass()"; } +def ConvertConvNchwToNhwc : + InterfacePass<"iree-flow-convert-conv-nchw-to-nhwc", "mlir::FunctionOpInterface"> { + let summary = "Convert linalg NCHW Convolutions to NHWC"; + let constructor = + "mlir::iree_compiler::IREE::createConvertConvNchwToNhwcPass()"; +} + def MakeSingleDispatchForFunction : Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> { let summary = "Convert entire function into a single dispatch"; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel index d11313bc5a98..044ffc329326 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "conv2d_nchw_to_nhwc.mlir", "conv2d_to_img2col.mlir", "make_single_dispatch_for_function.mlir", "pad_linalg_ops.mlir", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt index 19425cb4944d..c47bc7b0ea22 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "conv2d_nchw_to_nhwc.mlir" "conv2d_to_img2col.mlir" "make_single_dispatch_for_function.mlir" "pad_linalg_ops.mlir" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir new file mode 100644 index 000000000000..b7ab1af35a69 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir @@ -0,0 +1,40 @@ +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))" %s | FileCheck %s + +func.func @batch_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> { + %0 = linalg.conv_2d_nchw_fchw + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) + outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> + return %0 : tensor<8x16x14x14xf32> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1, d0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)> +// CHECK: @batch_conv +// CHECK: %[[INPUT:.+]]: tensor<8x4x16x16xf32> +// CHECK: %[[FILTER:.+]]: tensor<16x4x3x3xf32> +// CHECK: %[[OUTPUT:.+]]: tensor<8x16x14x14xf32> +// CHECK: %[[INIT_INPUT_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x16x16x4xf32> +// CHECK: %[[TRANSPOSED_INPUT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP1]] +// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>) outs(%[[INIT_INPUT_TRANSPOSE]] : tensor<8x16x16x4xf32>) +// CHECK: %[[INIT_FILTER_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<3x3x4x16xf32> +// CHECK: %[[TRANSPOSED_FILTER:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP2]] +// CHECK-SAME: ins(%[[FILTER]] : tensor<16x4x3x3xf32>) outs(%[[INIT_FILTER_TRANSPOSE]] : tensor<3x3x4x16xf32>) +// CHECK: %[[INIT_OUTPUT_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x14x14x16xf32> +// CHECK: %[[TRANSPOSED_OUTPUT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP1]] +// CHECK-SAME: ins(%[[OUTPUT]] : tensor<8x16x14x14xf32>) outs(%[[INIT_OUTPUT_TRANSPOSE]] : tensor<8x14x14x16xf32>) +// CHECK: %[[TRANSPOSED_RESULT:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[TRANSPOSED_INPUT]], %[[TRANSPOSED_FILTER]] : tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>) outs(%[[TRANSPOSED_OUTPUT]] : tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> +// CHECK: %[[INIT_RESULT:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x16x14x14xf32> +// CHECK: %[[RESULT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP3]] +// CHECK-SAME: ins(%[[TRANSPOSED_RESULT]] : tensor<8x14x14x16xf32>) outs(%[[INIT_RESULT]] : tensor<8x16x14x14xf32>) +// CHECK: return %[[RESULT]] : tensor<8x16x14x14xf32> From aed33788a377d587562d9f25dcc59f5c55e0de13 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Fri, 28 Oct 2022 17:45:51 -0700 Subject: [PATCH 04/38] [codegen][spirv] Pack/transpose matrix B for better coop mmma --- .../Flow/Transforms/FormDispatchRegions.cpp | 9 +- .../Dialect/Flow/Transforms/Passes.cpp | 6 + .../compiler/Preprocessing/Common/BUILD.bazel | 4 +- .../Preprocessing/Common/CMakeLists.txt | 2 + .../Common/ConvertLinalgMatmulToMmt.cpp | 119 ++++++++++++++ .../Common/GeneralizeAndFuse.cpp | 148 ++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.h | 6 + .../compiler/Preprocessing/Common/Passes.td | 12 ++ 8 files changed, 304 insertions(+), 2 deletions(-) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp index 11fba41d9a64..cd4a8987d59f 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp @@ -662,7 +662,14 @@ isFusableWithProducer(OpOperand &operand, } auto consumerLinalgOp = cast(consumer); - if (!consumerLinalgOp.isDpsInit(&operand)) { + if (consumerLinalgOp.isDpsInput(&operand)) { + // TODO: Add some marker on transpose and MatmulOp to indicate mmt. + bool fuseTransposeAndMatmul = + isa(consumer) && isa(producer); + if (fuseTransposeAndMatmul) { + return true; + } + } else if (!consumerLinalgOp.isDpsInit(&operand)) { return false; } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 5ca39ceb9a18..ca61270289e9 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -84,6 +84,12 @@ static llvm::cl::opt clDispatchGenerateWorkloadRegion( "iree-flow-dispatch-generate-workload-region", llvm::cl::desc("Generate the workload region."), llvm::cl::init(true)); + +static llvm::cl::opt clEnableTransposeMatmulLayout( + "iree-flow-enable-transpose-matmul-layout", + llvm::cl::desc("Enable transposing the B matrix for matmuls."), + llvm::cl::init(false)); + static llvm::cl::opt clNormalizeInputIndexingMap( "iree-flow-normalize-input-indexing-map", llvm::cl::desc("Enable normalizing input indexing map to identity."), diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index eb1f9fabdc0d..f8582d347074 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -32,8 +32,10 @@ iree_compiler_cc_library( srcs = [ "ConvertConv2DToImg2Col.cpp", "ConvertConvNchwToNhwc.cpp", + "ConvertLinalgMatmulToMmt.cpp", + "GeneralizeAndFuse.cpp", "MakeSingleDispatchForFunction.cpp", - "PadLinalgOps.cpp", + "PadLinalgOps.cpp", "PassDetail.h", "Passes.cpp", ], diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index b6ca3b0d3d5b..b5bd22be2d72 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -28,6 +28,8 @@ iree_cc_library( SRCS "ConvertConv2DToImg2Col.cpp" "ConvertConvNchwToNhwc.cpp" + "ConvertLinalgMatmulToMmt.cpp" + "GeneralizeAndFuse.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" "PassDetail.h" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp new file mode 100644 index 000000000000..f22e55cf92ac --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp @@ -0,0 +1,119 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +namespace { + +// Converts linalg.matmul to an linalg.transpose + linalg.matmul. +// Such that matrix B layout changes to col major. +class LinalgMatmulOpToLinalgMmtPattern final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + PatternRewriter &rewriter) const override { + Location loc = matmulOp.getLoc(); + Value lhs = matmulOp.getDpsInputOperand(0)->get(); + Value rhs = matmulOp.getDpsInputOperand(1)->get(); + Value acc = matmulOp.getDpsInitOperand(0)->get(); + if (dyn_cast(rhs.getDefiningOp())) { + return failure(); + } + auto rhsType = rhs.getType().cast(); + auto rhsShape = rhsType.getShape(); + auto rhsElemType = rhsType.getElementType(); + SmallVector transposedRhsShape = {rhsShape[1], rhsShape[0]}; + + // GenericOp + int64_t nloops = rhsShape.size(); + AffineExpr mDim, nDim; + bindDims(getContext(), mDim, nDim); + auto inputMap = AffineMap::get(2, 0, {mDim, nDim}, getContext()); + auto packedMap = AffineMap::get(2, 0, {nDim, mDim}, getContext()); + SmallVector indexingMaps = {inputMap, packedMap}; + + Value transposedRhs = + rewriter.create(loc, transposedRhsShape, rhsElemType); + SmallVector loopAttributeTypes( + nloops, utils::IteratorType::parallel); + + Value packedRhs = + rewriter + .create( + loc, transposedRhs.getType(), + /*inputs=*/rhs, /*outputs=*/transposedRhs, indexingMaps, + loopAttributeTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + nestedBuilder.create(nestedLoc, args[0]); + }) + .getResult(0); + + // TransposeOp + Value initOp = rewriter.create(loc, rhsShape, rhsElemType); + SmallVector transposedPerm = {1, 0}; + Value transposePackedRhs = + rewriter + .create(loc, packedRhs, initOp, transposedPerm) + .getResults()[0]; + + // MatmulOp + Value packedMatmul = + rewriter + .create(loc, matmulOp.getResult(0).getType(), + ArrayRef{lhs, transposePackedRhs}, + ArrayRef{acc}) + .getResult(0); + rewriter.replaceOp(matmulOp, packedMatmul); + return success(); + } +}; + +struct ConvertLinalgMatmulToMmtPass + : public ConvertLinalgMatmulToMmtBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + // Main pattern. + { + RewritePatternSet patterns(&getContext()); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + } +}; +} // namespace + +std::unique_ptr createConvertLinalgMatmulToMmtPass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp new file mode 100644 index 000000000000..8a4e7bee31e3 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp @@ -0,0 +1,148 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +namespace { + +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp) { + return linalg::isaContractionOpInterface(linalgOp) && + llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops()); +} + +//===----------------------------------------------------------------------===// +// Generalize and fusion patterns. +//===----------------------------------------------------------------------===// + +struct GeneralizeAndFusePass + : public GeneralizeAndFuseBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + template + class GeneralizeTargetNamedOpPattern final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LinalgOpType linalgOp, + PatternRewriter &rewriter) const override { + // TODO: Check consumer is transposeOp. + // TODO: Generalize transpos + FailureOr genericOp = + linalg::generalizeNamedOp(rewriter, linalgOp); + if (failed(genericOp)) return failure(); + return success(); + } + }; + + class FuseMatmulAndTranspose final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + // Inspo: + // https://github.com/llvm/llvm-project/blob/4f1c12425179608298dc39f5524ba2612609b5e4/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp + LogicalResult matchAndRewrite(linalg::GenericOp linalgOp, + PatternRewriter &rewriter) const override { + const unsigned rhsId = 1; + if (!isMatmulOrBatchMatmul(linalgOp)) return failure(); + Value rhs = linalgOp.getDpsInputOperand(rhsId)->get(); + auto transposeOp = dyn_cast(rhs.getDefiningOp()); + if (!transposeOp) return failure(); + auto perm = transposeOp.getPermutation(); + auto indexingMaps = linalgOp.getIndexingMaps(); + auto rhsMap = indexingMaps[rhsId].cast().getValue(); + int64_t rank = perm.size(); + if (rhsMap.getNumResults() != rank) return failure(); + SmallVector exprs; + for (auto dim_id : perm) { + exprs.push_back(rhsMap.getResult(dim_id)); + } + AffineMap transposedRhsMap = + AffineMap::get(rhsMap.getNumDims(), 0, exprs, getContext()); + + // TODO: Fold transposeOp as transposed indexing for matmulOp. + // Generate a map set. + auto lhsMap = indexingMaps[0].cast().getValue(); + auto accMap = indexingMaps[2].cast().getValue(); + SmallVector newIndexingMaps = {lhsMap, transposedRhsMap, + accMap}; + + // Generate new list of args. + Value newRhs = transposeOp.getDpsInputOperand(0)->get(); + Value lhs = linalgOp.getDpsInputOperand(0)->get(); + Value acc = linalgOp.getDpsInitOperand(0)->get(); + SmallVector inputs = {lhs, newRhs}; + + // Generate a new genericOp. + linalg::GenericOp genericOp = rewriter.create( + linalgOp.getLoc(), linalgOp.getResultTypes(), /*inputs*/ inputs, + /*outputs*/ acc, newIndexingMaps, linalgOp.getIteratorTypesArray()); + // Block consumerBlock = linalgOp->getRegion(0).front(); + // genericOp.getRegion().push_back(consumerBlock); + // llvm::outs()<<"new op + // regions:"<getNumRegions()<<"\n"; + // llvm::outs()<<"new op + // regions:"<getNumRegions()<<"\n"; + // llvm::outs()<<"new op + // blocks:"<getNumRegions()<<"\n"; + // llvm::outs()<<"old op + // blocks:"<getRegion(0), genericOp.getRegion(), + genericOp.getRegion().begin()); + rewriter.replaceOp(linalgOp, genericOp->getResults()); + return success(); + } + }; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + // Main pattern. + // Generalize + Fuse pattern. + { + RewritePatternSet patterns(&getContext()); + patterns.insert, + FuseMatmulAndTranspose>(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + } +}; +} // namespace + +std::unique_ptr createGeneralizeAndFusePass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index 6c897172d7d5..4831ff8d1ce0 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -24,6 +24,12 @@ std::unique_ptr createConvertConv2DToImg2ColPass(); std::unique_ptr> createConvertConvNchwToNhwcPass(); +// Pass to convert a linalg.matmul into linalg.transpose + linalg.matmul. +std::unique_ptr createConvertLinalgMatmulToMmtPass(); + +// Generalizes named op and try to fuse them +std::unique_ptr createGeneralizeAndFusePass(); + /// Moves the body of the entire function into a single dispatch. std::unique_ptr> createMakeSingleDispatchForFunctionPass(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index ed8b9194aecf..ca1c9913d21e 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -22,6 +22,18 @@ def ConvertConvNchwToNhwc : "mlir::iree_compiler::IREE::createConvertConvNchwToNhwcPass()"; } +def ConvertLinalgMatmulToMmt : + Pass<"iree-flow-convert-linalg-matmul-to-mmt", ""> { + let summary = "Convert linalg.matmul to linalg.transpose + linalg.matmul"; + let constructor = "mlir::iree_compiler::IREE::createConvertLinalgMatmulToMmtPass()"; +} + +def GeneralizeAndFuse : + Pass<"iree-flow-generalize-and-fuse", ""> { + let summary = "Generalizes named op and try to fuse them."; + let constructor = "mlir::iree_compiler::IREE::createGeneralizeAndFusePass()"; +} + def MakeSingleDispatchForFunction : Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> { let summary = "Convert entire function into a single dispatch"; From 7f8e66740d3c222b4e2914995493d4938ce0f165 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 12 Dec 2022 13:24:05 -0500 Subject: [PATCH 05/38] [winograd] Add winograd convolution attribute control as iree_winograd_conv --- .../Dialect/LinalgExt/Passes/Passes.h | 2 +- .../Passes/ConvertConv2DToWinograd.cpp | 28 +++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h index 9627486a62cc..54f9e947d36d 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h @@ -115,7 +115,7 @@ createTileAndDecomposeWinogradTransformPass(); // Creates a pass to convert linalg convolution ops into a sequence of // linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd // tranformation. -std::unique_ptr createConvertConv2DToWinogradPass(); +std::unique_ptr createConvertConv2DToWinogradPass(bool forceWinograd = false); // Transform dialect version of tile and decompose attention wrapper. void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp, diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp index 0680be4d8e8e..bece710a76bd 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp @@ -30,6 +30,8 @@ namespace iree_compiler { namespace IREE { namespace LinalgExt { +static const char winogradAttr[] = "iree_winograd_conv"; + static inline int index(int y, int x, int dimy, int dimx) { return (x + dimx * y); } @@ -134,10 +136,16 @@ template class FoldWinogradFilterTransform final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + FoldWinogradFilterTransform(MLIRContext *context, bool force) + : OpRewritePattern(context, force), forceWinograd(force) {} LogicalResult matchAndRewrite(ConvOp convOp, PatternRewriter &rewriter) const override { + // Attribute control unless forced. + if (!forceWinograd && !convOp->hasAttr(winogradAttr)) + return failure(); + bool isNchw; if (!isValidConv2d(convOp, isNchw)) return failure(); @@ -187,6 +195,8 @@ class FoldWinogradFilterTransform final : public OpRewritePattern { rewriter.replaceOpWithNewOp(constOp, foldedKernelAttr); return success(); } + private: + bool forceWinograd; }; } // namespace @@ -283,10 +293,16 @@ template class ConvertConvToWinograd final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + ConvertConvToWinograd(MLIRContext *context, bool force) + : OpRewritePattern(context, force), forceWinograd(force) {} LogicalResult matchAndRewrite(ConvOp convOp, PatternRewriter &rewriter) const override { + // Attribute control unless forced. + if (!forceWinograd && !convOp->hasAttr(winogradAttr)) + return failure(); + bool isNchw; if (!isValidConv2d(convOp, isNchw)) return failure(); @@ -416,10 +432,14 @@ class ConvertConvToWinograd final : public OpRewritePattern { result.replaceAllUsesWith(winogradOutput); return success(); } + private: + bool forceWinograd; }; struct ConvertConv2DToWinogradPass : ConvertConv2DToWinogradBase { + public: + ConvertConv2DToWinogradPass(bool forceWinograd) : forceWinograd(forceWinograd) {} void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); @@ -430,18 +450,20 @@ struct ConvertConv2DToWinogradPass patterns.insert, FoldWinogradFilterTransform, ConvertConvToWinograd, - ConvertConvToWinograd>(context); + ConvertConvToWinograd>(context, forceWinograd); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } + private: + bool forceWinograd; }; } // namespace -std::unique_ptr createConvertConv2DToWinogradPass() { - return std::make_unique(); +std::unique_ptr createConvertConv2DToWinogradPass(bool forceWinograd) { + return std::make_unique(forceWinograd); } } // namespace LinalgExt From 8502a65bd3bcc3b9e95204adcfe336ff3c5b3d50 Mon Sep 17 00:00:00 2001 From: Harsh Date: Tue, 13 Dec 2022 08:59:56 -0800 Subject: [PATCH 06/38] [Winograd] Winograd improvements - Speedup filter transform folding - Add points for 4x4, switch to that tile size - Move winograd after im2col + padding, in im2col do not touch conv if it has been marked as winograd -remove prints/chrono and adjust Attribute rawKernelAttr for windows by Quinn Co-authored-by: Quinn Dawkins --- .../Common/ConvertConv2DToImg2Col.cpp | 8 +++ .../LinalgExt/Utils/WinogradConstants.h | 56 +++++++++++++++++++ .../Passes/ConvertConv2DToWinograd.cpp | 45 +++++++++++---- .../Passes/TileAndDecomposeWinogradPass.cpp | 27 +++++++-- 4 files changed, 122 insertions(+), 14 deletions(-) diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp index 8bdb0120abdf..c06717a68959 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp @@ -21,6 +21,8 @@ namespace mlir::iree_compiler::Preprocessing { +static const char winogradAttr[] = "iree_winograd_conv"; + static bool hasAllOneValues(DenseIntElementsAttr attr) { return llvm::all_of( attr, [](APInt element) { return element.getSExtValue() == 1; }); @@ -94,6 +96,9 @@ class ConvertConv2DNhwcHwcf final if (!hasAllOneValues(convOp.getDilations())) return failure(); + // Ignore if marked as Winograd convolution + if (convOp->hasAttr(winogradAttr)) return failure(); + Value input = convOp.getInputs()[0]; Value filter = convOp.getInputs()[1]; Value output = convOp.getOutputs()[0]; @@ -403,6 +408,9 @@ class ConvertConv2DNchwFchw final if (!hasAllOneValues(convOp.getDilations())) return failure(); + // Ignore if marked as Winograd convolution + if (convOp->hasAttr(winogradAttr)) return failure(); + Value input = convOp.getInputs()[0]; Value filter = convOp.getInputs()[1]; Value output = convOp.getOutputs()[0]; diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h index 77dbb09135b2..62d93d14e5aa 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h @@ -82,6 +82,62 @@ const float A_6x6_3x3[] = { // clang-format on +//===----------------------------------------------------------------------===// +// Output tile size = 4, Kernel size = 3 +//===----------------------------------------------------------------------===// +// These constants were obtained from this paper: +// +// Lavin, A. et al (2016) Fast Algorithms for Convolution Neural Networks. +// https://openaccess.thecvf.com/content_cvpr_2016/papers/Lavin_Fast_Algorithms_for_CVPR_2016_paper.pdf +// + +// clang-format off + +const float BT_4x4_3x3[] = { + 4, 0, -5, 0, 1, 0, + 0, -4, -4, 1, 1, 0, + 0, 4, -4, -1, 1, 0, + 0, -2, -1, 2, 1, 0, + 0, 2, -1, -2, 1, 0, + 0, 4, 0, -5, 0, 1 +}; + +const float B_4x4_3x3[] = { + 4, 0, 0, 0, 0, 0, + 0, -4, 4, -2, 2, 4, + -5, -4, -4, -1, -1, 0, + 0, 1, -1, 2, -2, -5, + 1, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 1 +}; + +const float G_4x4_3x3[] = { + 1./4., 0, 0, + -1./6., -1./6., -1./6., + -1./6., 1./6., -1./6., + 1./24., 1./12., 1./6., + 1./24., -1./12., 1./6., + 0, 0, 1 +}; + +const float AT_4x4_3x3[] = { + 1, 1, 1, 1, 1, 0, + 0, 1, -1, 2, -2, 0, + 0, 1, 1, 4, 4, 0, + 0, 1, -1, 8, -8, 1 +}; + +const float A_4x4_3x3[] = { + 1, 0, 0, 0, + 1, 1, 1, 1, + 1, -1, 1, -1, + 1, 2, 4, 8, + 1, -2, 4, -8, + 0, 0, 0, 1 +}; + +// clang-format on + } // namespace Winograd } // namespace LinalgExt } // namespace IREE diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp index bece710a76bd..3a994f1b048f 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp @@ -24,6 +24,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" +#include +#include namespace mlir { namespace iree_compiler { @@ -47,7 +49,7 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) { // TODO: Make this a user-settable parameter once we have support // for more tile sizes -static constexpr int64_t outputTileSize = 6; +static constexpr int64_t outputTileSize = 4; /// This function computes the Winograd filter transform when /// the filter is known to be a constant. Specifically, this @@ -72,30 +74,50 @@ foldFilterTransform(ArrayRef shape, int64_t inputTileSize, const int &kw = isNchw ? shape[3] : shape[1]; const int &ic = isNchw ? shape[1] : shape[2]; const int &oc = isNchw ? shape[0] : shape[3]; + //printf("Folding filter with kh = %d, kw = %d, ic = %d, oc = %d\n", kh, kw, ic, oc); const int64_t numElements = inputTileSize * inputTileSize * ic * oc; + float *alloc{nullptr}; + if (!isSplat) { + alloc = (float *) malloc(kh * kw * ic * oc * sizeof(float)); + for (int d2 = 0; d2 < ic; d2++) { + for (int d3 = 0; d3 < oc; d3++) { + for (int d4 = 0; d4 < kernelSize; d4++) { + for (int d5 = 0; d5 < kernelSize; d5++) { + int idx; + if (!isNchw) { + idx = index(d4, d5, d2, d3, kh, kw, ic, oc); + } else { + idx = index(d3, d2, d4, d5, oc, ic, kh, kw); + } + alloc[idx] = input[idx].convertToFloat(); + } + } + } + } + } SmallVector output(numElements, APFloat(0.0f)); for (int d0 = 0; d0 < inputTileSize; d0++) { for (int d1 = 0; d1 < inputTileSize; d1++) { for (int d2 = 0; d2 < ic; d2++) { for (int d3 = 0; d3 < oc; d3++) { - APFloat accum(0.0f); + float accum(0.0f); for (int d4 = 0; d4 < kernelSize; d4++) { for (int d5 = 0; d5 < kernelSize; d5++) { - APFloat ival(splatValue); + float ival{splatValue}; if (!isSplat) { if (!isNchw) { - ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)]; + ival = alloc[index(d4, d5, d2, d3, kh, kw, ic, oc)]; } else { - ival = input[index(d3, d2, d4, d5, oc, ic, kh, kw)]; + ival = alloc[index(d3, d2, d4, d5, oc, ic, kh, kw)]; } } int idx0 = index(d0, d4, inputTileSize, kernelSize); int idx1 = index(d1, d5, inputTileSize, kernelSize); - accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]); + accum = accum + G[idx0] * ival * G[idx1]; } } int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc); - output[odx] = accum; + output[odx] = APFloat(accum); if (floatType.isF16()) { bool losesInfo; output[odx].convert(APFloat::IEEEhalf(), @@ -105,6 +127,7 @@ foldFilterTransform(ArrayRef shape, int64_t inputTileSize, } } } + if (alloc) free(alloc); return DenseElementsAttr::get(outputType, output); } @@ -165,10 +188,11 @@ class FoldWinogradFilterTransform final : public OpRewritePattern { const int64_t kernelSize = kh; const int64_t inputTileSize = outputTileSize + kernelSize - 1; - DenseIntOrFPElementsAttr kernelAttr; - if (!matchPattern(kernel, m_Constant(&kernelAttr))) { + Attribute rawKernelAttr; + if (!matchPattern(kernel, m_Constant(&rawKernelAttr)) || !isa(rawKernelAttr)) { return failure(); } + DenseIntOrFPElementsAttr kernelAttr = cast(rawKernelAttr); Operation *constOp = kernel.getDefiningOp(); ShapedType type = constOp->getResult(0).getType().cast(); @@ -190,8 +214,9 @@ class FoldWinogradFilterTransform final : public OpRewritePattern { auto resultType = RankedTensorType::get(resultShape, elemType); auto foldedKernelAttr = foldFilterTransform(shape, inputTileSize, kernelSize, resultType, - IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat, + IREE::LinalgExt::Winograd::G_4x4_3x3, isSplat, splatValue, nonSplatValues, elemType, isNchw); + rewriter.replaceOpWithNewOp(constOp, foldedKernelAttr); return success(); } diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp index 5e1bd34a273e..379e88f25276 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp @@ -210,10 +210,20 @@ static LogicalResult decomposeTiledWinogradInputTransformOp( loc, rewriter.getZeroAttr(elementType)); Value scratch = rewriter.create(loc, inputTileSquare, elementType); + const float *BT{nullptr}; const float *B{nullptr}; - B = IREE::LinalgExt::Winograd::B_6x6_3x3; - BT = IREE::LinalgExt::Winograd::BT_6x6_3x3; + const int64_t outputTileSize = + tiledWinogradInputTransformOp.getOutputTileSize(); + switch (outputTileSize) { + case 4: + B = IREE::LinalgExt::Winograd::B_4x4_3x3; + BT = IREE::LinalgExt::Winograd::BT_4x4_3x3; + break; + default: + B = IREE::LinalgExt::Winograd::B_6x6_3x3; + BT = IREE::LinalgExt::Winograd::BT_6x6_3x3; + } Value BTV = IREE::LinalgExt::createValueFrom2DConstant( BT, inputTileSize, inputTileSize, loc, rewriter); Value BV = IREE::LinalgExt::createValueFrom2DConstant( @@ -435,14 +445,23 @@ static LogicalResult decomposeTiledWinogradOutputTransformOp( "output operand expected to have rank-2"); ShapedType outputType = tiledWinogradOutputTransformOp.getOutputOperandType(); Type elementType = outputType.getElementType(); + const float *AT{nullptr}; const float *A{nullptr}; - A = IREE::LinalgExt::Winograd::A_6x6_3x3; - AT = IREE::LinalgExt::Winograd::AT_6x6_3x3; const int64_t inputTileSize = tiledWinogradOutputTransformOp.getInputTileSize(); const int64_t outputTileSize = tiledWinogradOutputTransformOp.getOutputTileSize(); + switch (outputTileSize) { + case 4: + A = IREE::LinalgExt::Winograd::A_4x4_3x3; + AT = IREE::LinalgExt::Winograd::AT_4x4_3x3; + break; + default: + A = IREE::LinalgExt::Winograd::A_6x6_3x3; + AT = IREE::LinalgExt::Winograd::AT_6x6_3x3; + } + /// The two values below are the transpose(A) [ATV] /// and A [AV] constant matrices that convert the output /// tile from the Winograd domain to the original domain. From 44a85322aaeba7969b9c398812738f9ac53f2701 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 23 Jan 2023 02:24:11 -0500 Subject: [PATCH 07/38] [API] Expose iree-opt in python for applying flow preprocessing passes --- compiler/bindings/python/CMakeLists.txt | 7 + compiler/bindings/python/IREEOptTool.c | 9 ++ .../python/iree/compiler/tools/binaries.py | 2 + .../python/iree/compiler/tools/core.py | 122 ++++++++++++++++++ 4 files changed, 140 insertions(+) create mode 100644 compiler/bindings/python/IREEOptTool.c diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index e47ff50ff964..6817758ae010 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -210,6 +210,13 @@ add_iree_compiler_busybox_tool( IREECompileTool.c ) +add_iree_compiler_busybox_tool( + IREECompilerIREEOptTool + OUTPUT_NAME iree-opt + SRCS + IREEOptTool.c +) + if(TARGET lld) add_iree_compiler_busybox_tool( IREECompilerLldTool diff --git a/compiler/bindings/python/IREEOptTool.c b/compiler/bindings/python/IREEOptTool.c new file mode 100644 index 000000000000..5c3f24133251 --- /dev/null +++ b/compiler/bindings/python/IREEOptTool.c @@ -0,0 +1,9 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/tool_entry_points_api.h" + +int main(int argc, char **argv) { return ireeOptRunMain(argc, argv); } diff --git a/compiler/bindings/python/iree/compiler/tools/binaries.py b/compiler/bindings/python/iree/compiler/tools/binaries.py index 7b8e592bf297..0ef1e604ebb9 100644 --- a/compiler/bindings/python/iree/compiler/tools/binaries.py +++ b/compiler/bindings/python/iree/compiler/tools/binaries.py @@ -30,6 +30,7 @@ _BUILTIN_TOOLS = [ "iree-compile", + "iree-opt", "iree-lld", ] @@ -42,6 +43,7 @@ # options. "iree-compile": "iree.tools.core", "iree-lld": "iree.tools.core", + "iree-opt": "iree.tools.core", "iree-import-tflite": "iree.tools.tflite", "iree-import-tf": "iree.tools.tf", } diff --git a/compiler/bindings/python/iree/compiler/tools/core.py b/compiler/bindings/python/iree/compiler/tools/core.py index 4eb7bc66229a..cd9829ff0451 100644 --- a/compiler/bindings/python/iree/compiler/tools/core.py +++ b/compiler/bindings/python/iree/compiler/tools/core.py @@ -25,6 +25,8 @@ "CompilerOptions", "InputType", "OutputFormat", + "preprocess_file", + "preprocess_str", ] # Default testing backend for invoking the compiler. @@ -318,3 +320,123 @@ def query_available_targets(): target_backends = [target for target in target_backends if target] return target_backends + + +# Preprocessing for SHARK (for now simply exposes iree-opt) + + +def build_opt_command_line( + input_file: str, tfs: TempFileSaver, options: CompilerOptions +) -> List[str]: + """Builds a command line for applying specified patterns. + + Args: + input_file: The input file name. + tfs: TempFileSaver. + options: Compiler options. + Returns: + List of strings of command line. + """ + iree_opt = find_tool("iree-opt") + cl = [ + iree_opt, + input_file, + ] + + # Output file. + if options.output_file: + cl.append(f"-o={options.output_file}") + + # Tool paths. + lld_path = find_tool("iree-lld") + cl.append(f"--iree-llvm-embedded-linker-path={lld_path}") + + crash_reproducer_path = tfs.alloc_optional( + "core-reproducer.mlir", export_as=options.crash_reproducer_path + ) + if crash_reproducer_path: + cl.append(f"--mlir-pass-pipeline-crash-reproducer={crash_reproducer_path}") + + cl.extend(options.extra_args) + print(cl) + return cl + + +def preprocess_file(input_file: str, **kwargs): + """Invokes iree-opt on an input file. + + Args: + input_file: File containing MLIR assembly to compile. + **kwargs: Keyword arguments corresponding to CompilerOptions. + Returns: + Either a byte buffer of the compiled content or None if output_file + was specified in the options. + """ + with TempFileSaver.implicit() as tfs: + options = CompilerOptions(**kwargs) + retained_output_file = tfs.alloc_optional( + "core-output.bin", export_as=options.output_file + ) + if options.output_file: + options.output_file = retained_output_file + cl = build_opt_command_line(input_file, tfs, options) + + # Save a temp file with the command line. + retained_cl = tfs.alloc_optional("core-command-line.txt") + if retained_cl: + with open(retained_cl, "wt") as f: + f.write(" ".join(cl)) + + result = invoke_immediate(cl) + if options.output_file: + return None + # Output as string needs to write to the retained output file itself. + if retained_output_file: + with open(retained_output_file, "wb") as f: + f.write(result) + return result + + +def preprocess_str(input_str: Union[str, bytes], **kwargs): + """Invokes the IREE compiler with an input string. + + Args: + input_str: MLIR assembly to parse/compile (str or bytes). + **kwargs: Keyword arguments corresponding to CompilerOptions. + Returns: + Either a byte buffer of the compiled content or None if output_file + was specified in the options. + """ + with TempFileSaver.implicit() as tfs: + retained_input_file = tfs.alloc_optional("core-input.mlir") + if retained_input_file: + with open( + retained_input_file, "wt" if isinstance(input_str, str) else "wb" + ) as f: + f.write(input_str) + options = CompilerOptions(**kwargs) + retained_output_file = tfs.alloc_optional( + "core-output.bin", export_as=options.output_file + ) + if options.output_file: + options.output_file = retained_output_file + cl = build_opt_command_line("-", tfs, options) + input_bytes = ( + input_str.encode("utf-8") if isinstance(input_str, str) else input_str + ) + + # Save a temp file with the command line. + retained_cl = tfs.alloc_optional("core-command-line.txt") + if retained_cl: + with open(retained_cl, "wt") as f: + f.write(" ".join(cl)) + + result = invoke_immediate(cl, immediate_input=input_bytes) + if options.output_file: + return None + + # Output as string needs to write to the retained output file itself. + if retained_output_file: + with open(retained_output_file, "wb") as f: + f.write(result) + return result From fbe9cba845c64fd1a9430e34aeb7797bdee97980 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 28 Mar 2023 11:37:38 -0700 Subject: [PATCH 08/38] [COMPILER] Add a plugin to split MLIR functions Add pass to insert markers for function bisecting. Add pass to outline marked operation ranges into separate functions. --- .../iree_compiler_plugin_group.cmake | 7 + experimental/split_mlir/lit.cfg.py | 41 +++ experimental/split_mlir/src/CMakeLists.txt | 20 ++ .../split_mlir/src/iree/CMakeLists.txt | 7 + .../src/iree/split_mlir/CMakeLists.txt | 52 +++ .../src/iree/split_mlir/MarkBisectPassImpl.h | 87 +++++ .../split_mlir/OutlineFunctionsPassImpl.h | 299 ++++++++++++++++++ .../split_mlir/src/iree/split_mlir/Passes.cpp | 8 + .../split_mlir/src/iree/split_mlir/Passes.h | 37 +++ .../split_mlir/src/iree/split_mlir/Passes.td | 79 +++++ .../iree/split_mlir/PluginRegistration.cpp | 32 ++ .../src/iree/split_mlir/test/CMakeLists.txt | 27 ++ .../split_mlir/test/function_outlining.mlir | 102 ++++++ .../src/iree/split_mlir/test/mark_bisect.mlir | 68 ++++ 14 files changed, 866 insertions(+) create mode 100644 experimental/split_mlir/iree_compiler_plugin_group.cmake create mode 100644 experimental/split_mlir/lit.cfg.py create mode 100644 experimental/split_mlir/src/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h create mode 100644 experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h create mode 100644 experimental/split_mlir/src/iree/split_mlir/Passes.cpp create mode 100644 experimental/split_mlir/src/iree/split_mlir/Passes.h create mode 100644 experimental/split_mlir/src/iree/split_mlir/Passes.td create mode 100644 experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir diff --git a/experimental/split_mlir/iree_compiler_plugin_group.cmake b/experimental/split_mlir/iree_compiler_plugin_group.cmake new file mode 100644 index 000000000000..5ec79cd63392 --- /dev/null +++ b/experimental/split_mlir/iree_compiler_plugin_group.cmake @@ -0,0 +1,7 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src split_mlir) diff --git a/experimental/split_mlir/lit.cfg.py b/experimental/split_mlir/lit.cfg.py new file mode 100644 index 000000000000..eba45391c230 --- /dev/null +++ b/experimental/split_mlir/lit.cfg.py @@ -0,0 +1,41 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Lint for undefined variables is disabled as config is not defined inside this +# file, instead config is injected by way of evaluating runlit.cfg.py from +# runlit.site.cfg.py which in turn is evaluated by lit.py. +# pylint: disable=undefined-variable + +import os +import tempfile + +import lit.formats + +config.name = "IREE" +config.suffixes = [".mlir", ".txt"] +config.test_format = lit.formats.ShTest(execute_external=True) + +# Forward all IREE environment variables, as well as some passthroughs. +# Note: env vars are case-insensitive on Windows, so check matches carefully. +# https://stackoverflow.com/q/7797269 +passthrough_env_vars = [ + # The Vulkan loader uses this + "VK_ICD_FILENAMES", + # WindowsLinkerTool uses these from vcvarsall + "VCTOOLSINSTALLDIR", + "UNIVERSALCRTSDKDIR", + "UCRTVERSION" +] +config.environment.update({ + k: v + for k, v in os.environ.items() + if k.startswith("IREE_") or k in passthrough_env_vars +}) + +# Use the most preferred temp directory. +config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or + os.environ.get("TEST_TMPDIR") or + os.path.join(tempfile.gettempdir(), "lit")) diff --git a/experimental/split_mlir/src/CMakeLists.txt b/experimental/split_mlir/src/CMakeLists.txt new file mode 100644 index 000000000000..da48cd621651 --- /dev/null +++ b/experimental/split_mlir/src/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_cc_library( + NAME + defs + INCLUDES + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + PUBLIC +) + +# Configures all iree_cc_* targets to take this implicit dep, +# which provides common includes and copts for the tree. +set(IREE_IMPLICIT_DEFS_CC_DEPS iree::experimental::split_mlir::src::defs) + +iree_add_all_subdirs() diff --git a/experimental/split_mlir/src/iree/CMakeLists.txt b/experimental/split_mlir/src/iree/CMakeLists.txt new file mode 100644 index 000000000000..33551b576974 --- /dev/null +++ b/experimental/split_mlir/src/iree/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() diff --git a/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt new file mode 100644 index 000000000000..63627bf3f76c --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt @@ -0,0 +1,52 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +iree_cc_library( + NAME + split_mlir_lib + HDRS + "Passes.h" + "Passes.h.inc" + SRCS + "Passes.cpp" + DEPS + ::PassesIncGen + MLIRFuncDialect + MLIRIR + MLIRPass + PUBLIC +) + +iree_cc_library( + NAME + registration + SRCS + "PluginRegistration.cpp" + DEPS + ::split_mlir_lib + MLIRIR + MLIRPass + iree::compiler::PluginAPI + PUBLIC +) + +iree_compiler_register_plugin( + PLUGIN_ID + split_mlir + TARGET + ::registration +) diff --git a/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h b/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h new file mode 100644 index 000000000000..1224dcd4818a --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h @@ -0,0 +1,87 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/split_mlir/Passes.h" +#include "llvm/ADT/SmallSet.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace iree { +namespace split_mlir { + +#define GEN_PASS_DEF_MARKBISECT +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +namespace { + +void markRangeFirst(Operation& op, OpBuilder& builder) { + op.setAttr("outline_range_first", builder.getUnitAttr()); +} + +void markRangeLast(Operation& op, OpBuilder& builder) { + op.setAttr("outline_range_last", builder.getUnitAttr()); +} + +struct MarkBisectPass : public impl::MarkBisectBase { + using MarkBisectBase::MarkBisectBase; + + LogicalResult initialize(MLIRContext* context) override { + functionsSet.insert(functions.begin(), functions.end()); + return LogicalResult::success(); + } + + void runOnOperation() override { + mlir::func::FuncOp funcOp = getOperation(); + if (!functionsSet.contains(funcOp.getSymName())) { + return; + } + if (funcOp.getBody().getBlocks().size() > 1) { + return signalPassFailure(); + } + Block& entryBlock = funcOp.getBody().front(); + if (entryBlock.getOperations().size() < 3) { + // Degenerate case. Needs at least 1 op for each half + the return op. + return; + } + size_t opsCount = entryBlock.getOperations().size(); + size_t cutOpIndex = (opsCount - 1) / 2; + OpBuilder builder(&getContext()); + // Ranges are inclusive, [first, last]. + auto firstHalfLastOp = entryBlock.begin(); + std::advance(firstHalfLastOp, cutOpIndex - 1); + markRangeFirst(entryBlock.front(), builder); + markRangeLast(*firstHalfLastOp, builder); + auto secondHalfFirstOp = firstHalfLastOp; + std::advance(secondHalfFirstOp, 1); + markRangeFirst(*secondHalfFirstOp, builder); + auto secondHalfLastOp = entryBlock.end(); + // Take operation that is just before the return operation. + std::advance(secondHalfLastOp, -2); + markRangeLast(*secondHalfLastOp, builder); + } + + private: + llvm::SmallSet functionsSet; +}; + +} // namespace + +std::unique_ptr> createMarkBisectPass() { + return std::make_unique(); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h b/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h new file mode 100644 index 000000000000..783520e7fdb9 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h @@ -0,0 +1,299 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include + +#include "iree/split_mlir/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace iree { +namespace split_mlir { + +#define GEN_PASS_DEF_OUTLINEFUNCTIONS +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +namespace { + +// Collect all operation ranges that are marked for outlining. +// The begining of a range is marked with the outline_range_first attribute. +// The last operation of a range is marked with the outline_range_last attribue. +// Example: +// %0 = arith.addi %arg0, %arg1 {outline_range_first} : i32 +// %1 = arith.addi %arg2, %arg3 : i32 +// %2 = arith.muli %arg3, %arg4 {outline_range_last} : i32 +// The outline range will consist of the 3 operations. +LogicalResult getOutlineOpRanges( + Block& block, SmallVector, 4>& res) { + bool isInOutliningRange = false; + Block::iterator rangeBegin; + for (Block::iterator opIt = block.begin(); opIt != block.end(); ++opIt) { + if (opIt->hasAttr("outline_range_first")) { + if (isInOutliningRange) { + return LogicalResult::failure(); + } + isInOutliningRange = true; + rangeBegin = opIt; + } + + if (opIt->hasAttr("outline_range_last")) { + if (!isInOutliningRange) { + return LogicalResult::failure(); + } + isInOutliningRange = false; + res.emplace_back(rangeBegin, std::next(opIt)); + } + } + if (isInOutliningRange) { + // No matching closing marker outline_range_last. + return LogicalResult::failure(); + } + + return LogicalResult::success(); +} + +// Return all values that are an operand of some of the given ops that are +// produced by other ops. Also return all values that are a result of some of +// the given ops and have uses outside the ops range. +std::pair, SmallVector> +getOperandsAndResultsForIsolation(iterator_range opRange, + const SmallPtrSet& opsSet) { + SmallVector operands; + SmallVector results; + SmallPtrSet operandsSet; + SmallPtrSet resultsSet; + for (Operation& op : opRange) { + for (Value operand : op.getOperands()) { + if (!opsSet.contains(operand.getDefiningOp())) { + auto insertionResult = operandsSet.insert(operand); + if (insertionResult.second) { + operands.push_back(operand); + } + } + } + for (OpResult result : op.getResults()) { + for (OpOperand operand : result.getUsers()) { + if (!opsSet.contains(operand.getOwner())) { + auto insertionResult = resultsSet.insert(result); + if (insertionResult.second) { + results.push_back(result); + } + break; + } + } + } + } + return {operands, results}; +} + +template +void replaceValueUsesWithNewBlockArguments(ValueIt valuesBegin, + ValueIt valuesEnd, Block& block) { + for (ValueIt valIt = valuesBegin; valIt != valuesEnd; ++valIt) { + block.addArgument(valIt->getType(), valIt->getLoc()); + BlockArgument& blockArg = block.getArguments().back(); + valIt->replaceUsesWithIf(blockArg, [&block](OpOperand& operand) { + return operand.getOwner()->getBlock() == █ + }); + } +} + +void addBlockReturn(Block& block, ValueRange operands, OpBuilder& builder) { + func::ReturnOp returnOp = + builder.create(builder.getUnknownLoc(), operands); + block.push_back(returnOp); +} + +void moveOpsIntoBlock(iterator_range opRange, Block& block) { + // Put ops into another container because opRange will be invalidated during + // removal. + SmallVector ops; + std::transform(opRange.begin(), opRange.end(), std::back_inserter(ops), + [](Operation& op) { return &op; }); + for (Operation* op : ops) { + op->moveBefore(&block, block.end()); + } +} + +void moveBlock(Region& srcRegion, Region& destRegion, + Region::iterator srcBlockIt, Region::iterator destBlockIt) { + Block* block = srcRegion.getBlocks().remove(srcBlockIt); + destRegion.getBlocks().insert(destBlockIt, block); +} + +bool isAncestorOfBlock(Operation* op, Block* block) { + // Walk up the operation hierarchy and check each block. + while (op != nullptr) { + if (op->getBlock() == block) { + return true; + } + op = op->getParentOp(); + } + return false; +} + +template +void substititeUses(OriginalOpResultsIt originalBegin, + OriginalOpResultsIt originalEnd, NewOpResultsIt newBegin, + NewOpResultsIt newEnd, Block& excludedBlock) { + assert(std::distance(originalBegin, originalEnd) == + std::distance(newBegin, newEnd)); + auto newIt = newBegin; + for (auto originalIt = originalBegin; originalIt != originalEnd; + ++originalIt, ++newIt) { + originalIt->replaceUsesWithIf(*newIt, [&excludedBlock](OpOperand& operand) { + return !isAncestorOfBlock(operand.getOwner(), &excludedBlock); + }); + } +} + +// All operations in the range `opRange` are moved into a new function with name +// `name`. The resulting function is put inside `moduleOp` and is properly +// isolated from above. This does not insert a call to the new function in place +// of the moved operations. +func::FuncOp createFunctionFromOps(iterator_range opRange, + StringRef name, ModuleOp moduleOp, + SmallVector& rangeOperands, + SmallVector& rangeResults, + OpBuilder& builder) { + Region& region = *opRange.begin()->getParentRegion(); + Block& dstBlock = region.emplaceBlock(); + moveOpsIntoBlock(opRange, dstBlock); + replaceValueUsesWithNewBlockArguments(rangeOperands.begin(), + rangeOperands.end(), dstBlock); + addBlockReturn(dstBlock, + ArrayRef(rangeResults.begin(), rangeResults.end()), + builder); + func::FuncOp funcOp = builder.create( + builder.getUnknownLoc(), name, + FunctionType::get(builder.getContext(), dstBlock.getArgumentTypes(), + dstBlock.back().getOperandTypes())); + moduleOp.getBodyRegion().getBlocks().front().push_back(funcOp); + moveBlock(region, funcOp.getBody(), std::prev(region.end()), + funcOp.getBody().end()); + + return funcOp; +} + +void createCall(func::FuncOp funcOp, Block& block, Block::iterator pos, + SmallVector& rangeOperands, + SmallVector& rangeResults, OpBuilder& builder) { + func::CallOp callOp = builder.create( + builder.getUnknownLoc(), funcOp, + ArrayRef(rangeOperands.begin(), rangeOperands.end())); + block.getOperations().insert(pos, callOp); + substititeUses(rangeResults.begin(), rangeResults.end(), + callOp.getResults().begin(), callOp.getResults().end(), + funcOp.getBody().back()); +} + +std::optional outlineOpRange( + iterator_range opRange, StringRef name, ModuleOp moduleOp, + OpBuilder& builder) { + if (opRange.empty()) { + return std::nullopt; + } + + SmallPtrSet opsSet; + for (Operation& op : opRange) { + opsSet.insert(&op); + } + SmallVector rangeOperands; + SmallVector rangeResults; + std::tie(rangeOperands, rangeResults) = + getOperandsAndResultsForIsolation(opRange, opsSet); + Block& srcBlock = *opRange.begin()->getBlock(); + + func::FuncOp funcOp = createFunctionFromOps( + opRange, name, moduleOp, rangeOperands, rangeResults, builder); + createCall(funcOp, srcBlock, opRange.end(), rangeOperands, rangeResults, + builder); + + return funcOp; +} + +std::string getOutlinedFuncName(StringRef prefix, int blockIndex, + int outlineRangeIndex) { + return (Twine(prefix) + "_outline_" + Twine(blockIndex) + "_" + + Twine(outlineRangeIndex)) + .str(); +} + +void removeOutlineMarkers(iterator_range opRange) { + if (opRange.empty()) { + return; + } + opRange.begin()->removeAttr("outline_range_first"); + std::prev(opRange.end())->removeAttr("outline_range_last"); +} + +// Each marked operation range in `funcOp` is outlined into a new function. +// A call to the new function is inserted in place of the outlined operations. +LogicalResult outlineOpRanges(func::FuncOp funcOp, ModuleOp moduleOp, + OpBuilder& builder) { + Region& funcBody = funcOp.getFunctionBody(); + SmallVector, 4> outlineRanges; + for (auto blockIt : llvm::enumerate(funcBody.getBlocks())) { + outlineRanges.clear(); + if (failed(getOutlineOpRanges(blockIt.value(), outlineRanges))) { + return LogicalResult::failure(); + } + for (auto rangeIt : llvm::enumerate(outlineRanges)) { + removeOutlineMarkers(rangeIt.value()); + std::string name = getOutlinedFuncName(funcOp.getSymName(), + blockIt.index(), rangeIt.index()); + outlineOpRange(rangeIt.value(), name, moduleOp, builder); + } + } + + return LogicalResult::success(); +} + +struct OutlineFunctionsPass + : public impl::OutlineFunctionsBase { + using OutlineFunctionsBase::OutlineFunctionsBase; + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + Block& moduleBlock = *moduleOp.getBody(); + OpBuilder builder(&getContext()); + // Get all functions since we are going to insert new ones + // that we don't want to iterate over. + SmallVector funcOps( + moduleBlock.getOps().begin(), + moduleBlock.getOps().end()); + for (func::FuncOp op : funcOps) { + if (failed(outlineOpRanges(op, moduleOp, builder))) { + return signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr> createOutlineFunctionsPass() { + return std::make_unique(); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.cpp b/experimental/split_mlir/src/iree/split_mlir/Passes.cpp new file mode 100644 index 000000000000..50fd44ad024d --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/Passes.cpp @@ -0,0 +1,8 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/split_mlir/MarkBisectPassImpl.h" +#include "iree/split_mlir/OutlineFunctionsPassImpl.h" diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.h b/experimental/split_mlir/src/iree/split_mlir/Passes.h new file mode 100644 index 000000000000..039d00727365 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/Passes.h @@ -0,0 +1,37 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_ +#define IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class ModuleOp; +namespace func { +class FuncOp; +} // namespace func + +namespace iree { +namespace split_mlir { + +#define GEN_PASS_DECL +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +std::unique_ptr> createOutlineFunctionsPass(); +std::unique_ptr> createMarkBisectPass(); + +#define GEN_PASS_REGISTRATION +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +} // namespace split_mlir +} // namespace iree +} // namespace mlir + +#endif // IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_ diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.td b/experimental/split_mlir/src/iree/split_mlir/Passes.td new file mode 100644 index 000000000000..f07f37dfd5a6 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/Passes.td @@ -0,0 +1,79 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_SPLIT_MLIR_TRANSFORM_PASSES +#define IREE_SPLIT_MLIR_TRANSFORM_PASSES + +include "mlir/Pass/PassBase.td" + +def OutlineFunctions : + Pass<"iree-outline-functions", "mlir::ModuleOp"> { + let summary = "Outline operations in separate function(s)."; + let description = [{ + Marked operation ranges in a function block are outlined/moved into new functions. + In place of an outlined operations range is inserted a call to the new function. + The resulting function is equivalent to the original. + The ranges for outlining must be marked with the attributes + `outline_range_first`, `outline_range_last`. + + Example: + ```mlir + func.func @f(%arg0: i32, %arg1: i32) -> i32 { + %0 = arith.addi %arg0, %arg1 {outline_range_first} : i32 + %1 = arith.muli %0, %0 : i32 + %2 = arith.muli %1, %1 {outline_range_last} : i32 + %3 = arith.addi %2, %2 : i32 + return %3 : i32 + } + ``` + + The above MLIR will be transformed to: + ```mlir + func.func @f(%arg0: i32, %arg1: i32) -> i32 { + %0 = call @f_outline_0_0(%arg0, %arg1) : (i32, i32) -> i32 + %1 = arith.addi %0, %0 : i32 + return %1 : i32 + } + func.func @f_outline_0_0(%arg0: i32, %arg1: i32) -> i32 { + %0 = arith.addi %arg0, %arg1 : i32 + %1 = arith.muli %0, %0 : i32 + %2 = arith.muli %1, %1 : i32 + return %2 : i32 + } + ``` + + The pass will fail if there is branching to other function blocks + inside a marked operation range. + }]; + let constructor = "mlir::iree::split_mlir::createOutlineFunctionsPass()"; + let dependentDialects = ["mlir::func::FuncDialect"]; +} + +def MarkBisect : Pass<"iree-mark-bisect", "mlir::func::FuncOp"> { + let summary = "Mark operations in function(s) for outlining with bisect strategy."; + let description = [{ + Each function's entry block is bisected, + such that each piece has balanced number of ops. + The two pieces are marked with attributes `outline_range_first` and + `outline_range_last`. These markings surve as input to the `OutlineFunctions` pass. + + Example: + ```bash + iree-opt \ + --iree-plugin=split_mlir \ + --pass-pipeline="builtin.module(func.func(iree-mark-bisect{functions=f,g}))" + my.mlir + ``` + + }]; + let constructor = "mlir::iree::split_mlir::createMarkBisectPass()"; + let options = [ + ListOption<"functions", "functions", "std::string", "List of functions to bisect."> + ]; + let dependentDialects = ["mlir::func::FuncDialect"]; +} + +#endif // IREE_SPLIT_MLIR_TRANSFORM_PASSES diff --git a/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp b/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp new file mode 100644 index 000000000000..762bfe327b7d --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp @@ -0,0 +1,32 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/PluginAPI/Client.h" +#include "iree/split_mlir/Passes.h" + + +using namespace mlir; +using namespace mlir::iree_compiler; + +namespace { + +struct SplitMlirOptions { + void bindOptions(OptionsBinder &binder) {} +}; + +struct SplitMlirSession : public PluginSession { + static void registerPasses() { + iree::split_mlir::registerPasses(); + } +}; +} // namespace + +IREE_DEFINE_COMPILER_OPTION_FLAGS(SplitMlirOptions); + +extern "C" bool iree_register_compiler_plugin_split_mlir(PluginRegistrar *registrar) { + registrar->registerPlugin("split_mlir"); + return true; +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt new file mode 100644 index 000000000000..27c8e39e0be0 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "mark_bisect.mlir" + TOOLS + FileCheck + iree-opt +) + +iree_lit_test_suite( + NAME + lit + SRCS + "function_outlining.mlir" + TOOLS + FileCheck + iree-opt +) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir b/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir new file mode 100644 index 000000000000..cb8df4596192 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir @@ -0,0 +1,102 @@ +// RUN: iree-opt \ +// RUN: --split-input-file \ +// RUN: --iree-plugin=split_mlir \ +// RUN: --pass-pipeline="builtin.module(iree-outline-functions)" %s \ +// RUN: | FileCheck --dump-input-context=100 %s + +// Outline op that does not take any arguments and is not used anywhere. +// CHECK-LABEL: func.func @no_args_and_result +func.func @no_args_and_result() { +// CHECK: call @no_args_and_result_outline_0_0() : () -> () + %cts1 = mhlo.constant {outline_range_first, outline_range_last} dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: {{return$}} + return +} +// CHECK-LABEL: func.func @no_args_and_result_outline_0_0() +// CHECK: mhlo.constant dense<{{.+}}> : tensor<2xf32> +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last +// CHECK-NEXT: {{return$}} + +// ----- + +// Outline an op that takes one argument and has one result that is used. +// CHECK-LABEL: func.func @one_arg_and_one_result +// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>) -> tensor<2xf32> +func.func @one_arg_and_one_result(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK-NEXT: [[RES0:%.+]] = call @one_arg_and_one_result_outline_0_0([[ARG0]]) + %res = mhlo.cosine %arg0 {outline_range_first, outline_range_last} : tensor<2xf32> +// CHECK-NEXT: return [[RES0]] : tensor<2xf32> + return %res : tensor<2xf32> +} +// CHECK-LABEL: func.func @one_arg_and_one_result_outline_0_0 +// CHECK-SAME: ([[ARG1:%.+]]: tensor<2xf32>) -> tensor<2xf32> +// CHECK-NEXT: [[RES1:%.+]] = mhlo.cosine [[ARG1]] : tensor<2xf32> +// CHECK-NEXT: return [[RES1]] : tensor<2xf32> + +// ----- + +// Multiple ops in a range with multiple arguments and results. +// CHECK-LABEL: func.func @multiple_ops +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32) +func.func @multiple_ops(%arg0: i32, %arg1: i32) -> (i32, i32) { +// CHECK-NEXT: [[RES0:%.+]]:2 = call @multiple_ops_outline_0_0([[ARG0]], [[ARG1]]) : (i32, i32) -> (i32, i32) + %add = arith.addi %arg0, %arg0 {outline_range_first} : i32 + %mul = arith.muli %add, %arg1 {outline_range_last} : i32 +// CHECK-NEXT: return [[RES0]]#0, [[RES0]]#1 : i32, i32 + return %add, %mul : i32, i32 +} +// CHECK-LABEL: func.func @multiple_ops_outline_0_0 +// CHECK-SAME: ([[ARG10:%.+]]: i32, [[ARG11:%.+]]: i32) -> (i32, i32) +// CHECK-NEXT: [[ADD:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32 +// CHECK-NEXT: [[MUL:%.+]] = arith.muli [[ADD]], [[ARG11]] : i32 +// CHECK-NEXT: return [[ADD]], [[MUL]] : i32, i32 + +// ----- + +// Outline multiple ranges in the same function. +// CHECK-LABEL: func.func @multiple_ranges_in_same_func +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32) +func.func @multiple_ranges_in_same_func(%arg0: i32, %arg1: i32) -> (i32, i32) { +// CHECK-NEXT: [[ADD:%.+]] = call @multiple_ranges_in_same_func_outline_0_0([[ARG0]]) : (i32) -> i32 + %add = arith.addi %arg0, %arg0 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: [[MUL:%.+]] = call @multiple_ranges_in_same_func_outline_0_1([[ADD]], [[ARG1]]) : (i32, i32) -> i32 + %mul = arith.muli %add, %arg1 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: return [[ADD]], [[MUL]] : i32, i32 + return %add, %mul : i32, i32 +} +// CHECK-LABEL: func.func @multiple_ranges_in_same_func_outline_0_0 +// CHECK-SAME: ([[ARG10:%.+]]: i32) -> i32 +// CHECK-NEXT: [[ADD1:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32 +// CHECK-NEXT: return [[ADD1]] : i32 +// CHECK-LABEL: func.func @multiple_ranges_in_same_func_outline_0_1 +// CHECK-SAME: ([[ARG20:%.+]]: i32, [[ARG21:%.+]]: i32) -> i32 +// CHECK-NEXT: [[MUL2:%.+]] = arith.muli [[ARG20]], [[ARG21]] : i32 +// CHECK-NEXT: return [[MUL2]] : i32 + +// ----- + +// Outline multiple ranges in different blocks. +// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 +func.func @multiple_ranges_in_different_blocks(%arg0: i32, %arg1: i32) -> i32 { +// CHECK-NEXT: [[ADD:%.+]] = call @multiple_ranges_in_different_blocks_outline_0_0([[ARG0]]) : (i32) -> i32 + %add = arith.addi %arg0, %arg0 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: cf.br ^bb1([[ARG1]] : i32) + cf.br ^bb1(%arg1 : i32) +// CHECK-NEXT: ^bb1 +// CHECK-SAME: ([[ARG2:%.+]]: i32) +^bb1 (%arg2: i32): +// CHECK-NEXT: [[MUL:%.+]] = call @multiple_ranges_in_different_blocks_outline_1_0([[ADD]], [[ARG2]]) : (i32, i32) -> i32 + %mul = arith.muli %add, %arg2 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: return [[MUL]] : i32 + return %mul : i32 +} +// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks_outline_0_0 +// CHECK-SAME: ([[ARG10:%.+]]: i32) -> i32 +// CHECK-NEXT: [[ADD1:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32 +// CHECK-NEXT: return [[ADD1]] : i32 +// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks_outline_1_0 +// CHECK-SAME: ([[ARG20:%.+]]: i32, [[ARG21:%.+]]: i32) -> i32 +// CHECK-NEXT: [[MUL2:%.+]] = arith.muli [[ARG20]], [[ARG21]] : i32 +// CHECK-NEXT: return [[MUL2]] : i32 diff --git a/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir b/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir new file mode 100644 index 000000000000..3a4cb72424a7 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir @@ -0,0 +1,68 @@ +// RUN: iree-opt \ +// RUN: --split-input-file \ +// RUN: --iree-plugin=split_mlir \ +// RUN: --pass-pipeline="builtin.module(func.func(iree-mark-bisect{functions=two_ops,too_few_ops,multiple_ops}))" %s \ +// RUN: | FileCheck %s + +// Each operation is marked as separate range. +// CHECK-LABEL: func.func @two_ops +func.func @two_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: mhlo.constant +// CHECK-DAG: outline_range_first +// CHECK-DAG: outline_range_last + %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: mhlo.add +// CHECK-DAG: outline_range_first +// CHECK-DAG: outline_range_last + %res = mhlo.add %arg0, %cts1 : tensor<2xf32> + return %res : tensor<2xf32> +} + +// ----- + +// Degenerate case with too few ops should not mark enything. +// CHECK-LABEL: func.func @too_few_ops +func.func @too_few_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: mhlo.constant +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last + %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: return +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last + return %cts1 : tensor<2xf32> +} + +// ----- + +// Multiple ops per range. +// CHECK-LABEL: func.func @multiple_ops +func.func @multiple_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: outline_range_first +// CHECK-SAME: dense<1.000000e+00> + %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: outline_range_last +// CHECK-SAME: dense<2.000000e+00> + %cts2 = mhlo.constant dense<2.000000e+00> : tensor<2xf32> +// CHECK-NEXT: outline_range_first +// CHECK-SAME: dense<3.000000e+00> + %cts3 = mhlo.constant dense<3.000000e+00> : tensor<2xf32> +// CHECK-NEXT: outline_range_last +// CHECK-SAME: dense<4.000000e+00> + %cts4 = mhlo.constant dense<4.000000e+00> : tensor<2xf32> +// CHECK-NEXT: return + return %cts1 : tensor<2xf32> +} + +// ----- + +// Non-listed functions should not be marked. +// CHECK-LABEL: func.func @function_not_to_mark +func.func @function_not_to_mark(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last + %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32> + %res = mhlo.add %arg0, %cts1 : tensor<2xf32> +// CHECK: return + return %res : tensor<2xf32> +} From 1269f37e1d107cef854449fd545603c2ecaabf6b Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 5 Apr 2023 08:39:26 -0700 Subject: [PATCH 09/38] [split_mlir] Add operation list extraction and execution for IREE Expose a Python binding that has extraction of an operation list from an MLIR file. This list is then used to execute with IREE the entry MLIR while resolving calls to functions in other MLIR files. --- .../src/iree/split_mlir/CMakeLists.txt | 28 +++ .../src/iree/split_mlir/OperationListImpl.h | 181 ++++++++++++++++++ .../src/iree/split_mlir/SplitMlirPyExt.cpp | 37 ++++ .../src/iree/split_mlir/__init__.py | 14 ++ .../src/iree/split_mlir/_split_mlir.pyi | 12 ++ .../src/iree/split_mlir/execution.py | 42 ++++ .../src/iree/split_mlir/iree_execution.py | 122 ++++++++++++ .../split_mlir/test/execution/CMakeLists.txt | 18 ++ .../iree/split_mlir/test/execution/entry.mlir | 14 ++ .../test/execution/execution_test.py | 57 ++++++ .../iree/split_mlir/test/execution/f1.mlir | 10 + .../iree/split_mlir/test/execution/f2.mlir | 10 + .../split_mlir/src/iree/split_mlir/types.py | 19 ++ 13 files changed, 564 insertions(+) create mode 100644 experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h create mode 100644 experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp create mode 100644 experimental/split_mlir/src/iree/split_mlir/__init__.py create mode 100644 experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi create mode 100644 experimental/split_mlir/src/iree/split_mlir/execution.py create mode 100644 experimental/split_mlir/src/iree/split_mlir/iree_execution.py create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir create mode 100644 experimental/split_mlir/src/iree/split_mlir/types.py diff --git a/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt index 63627bf3f76c..d59a3df68730 100644 --- a/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt +++ b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt @@ -50,3 +50,31 @@ iree_compiler_register_plugin( TARGET ::registration ) + +iree_pyext_module( + NAME + PyExt + MODULE_NAME _split_mlir + SRCS + "OperationListImpl.h" + "SplitMlirPyExt.cpp" + DEPS + MLIRFuncDialect + MLIRIR + MLIRAsmParser + iree::compiler::Tools::init_passes_and_dialects +) + +iree_py_library( + NAME + split_mlir_py + SRCS + "__init__.py" + "_split_mlir.pyi" + "execution.py" + "iree_execution.py" + "types.py" + DEPS + MLIRPythonModules + ::PyExt +) diff --git a/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h b/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h new file mode 100644 index 000000000000..df2f4180a2d3 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h @@ -0,0 +1,181 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "iree/compiler/Tools/init_dialects.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree { +namespace split_mlir { + +using OpId = std::string; +using OpIndex = size_t; +using ResultIndex = size_t; +using ResultId = std::tuple; +using Arguments = std::vector; +using OperationList = std::vector>; + +ResultIndex getResultIndex(OpOperand& operand) { + OpResult opResult = operand.get().dyn_cast(); + if (opResult) { + return opResult.getResultNumber(); + } + + BlockArgument blockArgument = operand.get().dyn_cast(); + assert(blockArgument); + return blockArgument.getArgNumber(); +} + +FailureOr getDefiningOpIndex( + OpOperand& operand, Block& block, + const std::unordered_map& operationInBlockIndexMap) { + Value value = operand.get(); + if (value.isa()) { + return 0; + } + + OpResult opResult = value.dyn_cast(); + if (!opResult) { + operand.getOwner()->emitError( + Twine("Operand ") + std::to_string(operand.getOperandNumber()) + + "is neigher a block argument or a result of an operation"); + return failure(); + } + if (value.getDefiningOp()->getBlock() != &block) { + operand.getOwner()->emitError( + "Can't extract call graph for block that is not isolated from above."); + return failure(); + } + + auto it = operationInBlockIndexMap.find(value.getDefiningOp()); + assert(it != operationInBlockIndexMap.end()); + return it->second; +} + +std::string getOpId(Operation& op) { + func::CallOp callOp = dyn_cast(op); + if (callOp) { + return (Twine("call ") + callOp.getCallee()).str(); + } + + if (isa(op)) { + return "return"; + } + + return op.getName().getStringRef().str(); +} + +FailureOr extractOperationList(Block& block) { + OperationList res; + // Block arguments don't depend on anything. + res.emplace_back(); + // Index inside the block. + std::unordered_map operationInBlockIndexMap; + + for (auto opIt : llvm::enumerate(block)) { + operationInBlockIndexMap.insert({&opIt.value(), opIt.index() + 1}); + OpId id = getOpId(opIt.value()); + Arguments arguments; + for (OpOperand& operand : opIt.value().getOpOperands()) { + FailureOr opIndex = + getDefiningOpIndex(operand, block, operationInBlockIndexMap); + FailureOr resultIndex = getResultIndex(operand); + if (failed(opIndex) || failed(resultIndex)) { + return failure(); + } + arguments.emplace_back(opIndex.value(), resultIndex.value()); + } + res.emplace_back(id, arguments); + } + + return res; +} + +FailureOr> loadMlir(const char* mlirFilePath, + MLIRContext& context) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(mlirFilePath); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + return parseSourceFile(sourceMgr, &context); +} + +func::FuncOp findFunction(Operation* root, StringRef name) { + func::FuncOp res; + root->walk([&](func::FuncOp op) { + if (op.getSymName() == name) { + res = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return res; +} + +FailureOr extractOperationList(ModuleOp moduleOp, + StringRef functionName) { + func::FuncOp funcOp = findFunction(moduleOp.getOperation(), functionName); + Region* region = funcOp.getCallableRegion(); + if (!region) { + funcOp.emitError("No callable region found."); + return failure(); + } + if (region->getBlocks().size() != 1) { + funcOp.emitError("Blocks count must be exactly 1."); + return failure(); + } + return extractOperationList(region->front()); +} + +FailureOr extractOperationList(const char* mlirFilePath, + StringRef functionName, + MLIRContext& context) { + auto moduleOp = loadMlir(mlirFilePath, context); + if (failed(moduleOp)) { + return failure(); + } + + return extractOperationList(moduleOp->get(), functionName); +} + +std::unique_ptr makeMlirContext() { + mlir::DialectRegistry registry; + mlir::iree_compiler::registerAllDialects(registry); + auto context = std::make_unique(registry); + return context; +} + +FailureOr extractOperationList(const char* mlirFilePath, + StringRef functionName) { + auto context = makeMlirContext(); + return extractOperationList(mlirFilePath, functionName, *context); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp b/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp new file mode 100644 index 000000000000..d1b5560fe660 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp @@ -0,0 +1,37 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include + +#include "OperationListImpl.h" + +namespace py = pybind11; + +namespace mlir { +namespace iree { +namespace split_mlir { + +PYBIND11_MODULE(_split_mlir, m) { + m.doc() = "Split MLIR C++ extension"; + + m.def( + "extract_operation_list", + [](const std::string& mlirFilePath, const std::string& functionName) { + auto res = extractOperationList(mlirFilePath.c_str(), functionName); + if (failed(res)) { + throw std::runtime_error(""); + } + return res.value(); + }, + py::arg("mlir_file_path"), py::arg("function_name")); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/__init__.py b/experimental/split_mlir/src/iree/split_mlir/__init__.py new file mode 100644 index 000000000000..f66106cd2b4a --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._split_mlir import extract_operation_list +from .execution import execute_operation_list +from .iree_execution import IreeExecutor, execute_mlir_with_iree + +__all__ = [ + "execute_operation_list", "execute_mlir_with_iree", + "extract_operation_list", "IreeExecutor" +] diff --git a/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi b/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi new file mode 100644 index 000000000000..f21ca121c93e --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi @@ -0,0 +1,12 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .types import OperationList + + +def extract_operation_list(mlir_file_path: str, + function_name: str) -> OperationList: + ... diff --git a/experimental/split_mlir/src/iree/split_mlir/execution.py b/experimental/split_mlir/src/iree/split_mlir/execution.py new file mode 100644 index 000000000000..ddf8988fb39e --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/execution.py @@ -0,0 +1,42 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, Optional +from .types import OpArguments, Tensor, ExecuteOp, OperationList + + +def collect_arguments(arguments: OpArguments, + results: List[List[Tensor]]) -> List[Tensor]: + return [results[arg[0]][arg[1]] for arg in arguments] # type: ignore + + +def execute_operation_list( + input: List[Tensor], + operation_list: OperationList, + execute_op: ExecuteOp, + override_results: Optional[List[List[Tensor]]] = None +) -> List[List[Tensor]]: + """Algorithm to execute a call list. + + Parameters + ---------- + input : Input of the graph. + execute_op : Callable that executes an operation from the graph. + override_results : When execting operations override arguments with this values, + instead of using the computed resuts from previous functions. + + Returns + ------- + All results from all operations is the graph are in the same order + as they appear in `operation_list`. `input` is prepened to the result. + """ + results = [input] + for op in operation_list[1:]: + arguments = collect_arguments( + arguments=op[1], + results=results if override_results is None else override_results) + results.append(execute_op(op[0], arguments)) + return results diff --git a/experimental/split_mlir/src/iree/split_mlir/iree_execution.py b/experimental/split_mlir/src/iree/split_mlir/iree_execution.py new file mode 100644 index 000000000000..cfcfe92a3a90 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/iree_execution.py @@ -0,0 +1,122 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, Callable, Tuple, Dict, Optional, Any +from .types import Tensor +import iree.runtime +from iree.runtime import VmModule, HalDevice, load_vm_module +from .execution import execute_operation_list +from tempfile import TemporaryDirectory +from iree.compiler.tools import compile_file +import os +from pathlib import Path +from collections import namedtuple +from ._split_mlir import extract_operation_list +from numbers import Number + +VmfbFilePath = str +MlirFilePath = str +FunctionName = str + + +class IreeExecutor: + """Executor for IREE that implements the `.types.ExecuteOp` interface.""" + + def __init__(self, device: HalDevice, + resolve_function: Callable[[FunctionName], Tuple[VmfbFilePath, + FunctionName]]): + """ + Parameters + ---------- + resolve_function : Resolves a function name that is called in the entry MLIR to + the vmfb file where it can be found under another name. + """ + self.device = device + self.resolve_function = resolve_function + + def __call__(self, op_id: str, operands: List[Tensor]) -> List[Tensor]: + if op_id.startswith("call "): + function_name = op_id.split(" ", 2)[1] + vmfb_file_path, vmfb_function_name = self.resolve_function(function_name) + config = iree.runtime.Config(device=self.device) + with open(vmfb_file_path, "rb") as f: + vm_flatbuffer = f.read() + vm_module_fb_bytes = VmModule.from_flatbuffer(config.vm_instance, + vm_flatbuffer) + vm_module = load_vm_module(vm_module_fb_bytes, config) + res = getattr(vm_module, vmfb_function_name)(*operands) + if isinstance(res, (iree.runtime.DeviceArray, Number)): + res = [res] + return res + if op_id == "return": + return operands + raise RuntimeError(f"Invalid op_id \"{op_id}\".") + + +def mlir_to_vmfb_file_path(mlir_file_path: str) -> str: + return f"{Path(mlir_file_path).stem}.vmfb" + + +def execute_mlir_with_iree(input: List[Tensor], + mlir_path_function_pairs: List[Tuple[MlirFilePath, + FunctionName]], + compile_kwargs: Dict[str, Any], + device: HalDevice, + override_results: Optional[List[ + List[Tensor]]] = None, + artifact_dir: Optional[str] = None) -> List[Tensor]: + """Executes an MLIR program that is split accorss multiple MLIR files. + Parameters + ---------- + mlir_path_function_pairs : List of MLIR files and the function they contain. + The first element is the entry MLIR and function. + It is expected that a name of function called in the entry function correspnd + to an MLIR file with the same name without file name extension. + compile_kwargs : Compile arguments to pass to iree.compiler.tools.compile_file. + artifact_dir : Where to put temporary files. + Defaults to creating a unique temporary directory that is deleted on completion. + + See: `execute_operation_list` + """ + if artifact_dir is None: + with TemporaryDirectory() as temp_dir: + return execute_mlir_with_iree( + input=input, + mlir_path_function_pairs=mlir_path_function_pairs, + override_results=override_results, + compile_kwargs=compile_kwargs, + device=device, + artifact_dir=temp_dir) + + entry_mlir_file_path = mlir_path_function_pairs[0][0] + entry_function_name = mlir_path_function_pairs[0][1] + FunctionDescription = namedtuple( + "FunctionDescription", + ["mlir_file_path", "vmfb_file_path", "function_name"]) + function_map = { + Path(Path(p[0]).name).stem: FunctionDescription( + p[0], os.path.join(artifact_dir, mlir_to_vmfb_file_path(p[0])), p[1]) + for p in mlir_path_function_pairs + } + for i in range(1, len(mlir_path_function_pairs)): + function_description = function_map[Path( + Path(mlir_path_function_pairs[i][0]).name).stem] + compile_file(function_description.mlir_file_path, + output_file=function_description.vmfb_file_path, + **compile_kwargs) + + def resolve_function( + function_name: FunctionName) -> Tuple[VmfbFilePath, FunctionName]: + func_desc = function_map[function_name] + return (func_desc.vmfb_file_path, func_desc.function_name) + + executor = IreeExecutor(device=device, resolve_function=resolve_function) + operation_list = extract_operation_list(mlir_file_path=entry_mlir_file_path, + function_name=entry_function_name) + return execute_operation_list(operation_list=operation_list, + execute_op=executor, + input=input, + override_results=override_results) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt new file mode 100644 index 000000000000..77611fc4278a --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_local_py_test( + NAME + execution_test + SRC + execution_test.py + PACKAGE_DIRS + "${IREE_BINARY_DIR}/compiler/bindings/python" + "${IREE_BINARY_DIR}/runtime/bindings/python" + "${IREE_BINARY_DIR}/compiler/plugins/split_mlir" +) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir new file mode 100644 index 000000000000..ea6cd8b5f4e5 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir @@ -0,0 +1,14 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func nested @f1(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) +func.func nested @f2(%arg0: tensor<1xf32>) -> tensor<1xf32> + +func.func @caller(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { + %0:2 = call @f1(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) + %1 = call @f2(%0#0) : (tensor<1xf32>) -> tensor<1xf32> + return %arg1, %1 : tensor<1xf32>, tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py b/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py new file mode 100644 index 000000000000..98e5b766eacd --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py @@ -0,0 +1,57 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from iree.split_mlir import extract_operation_list, execute_mlir_with_iree +from typing import List, Any +import os +import iree.runtime +import numpy as np + + +def assert_nested_array_equals(a: List[Any], b: List[Any]): + assert a == b, f"{a} != {b}" + + +class ExecutionTest(unittest.TestCase): + + def test_extract_operation_list(self): + expected_operation_list = [ + ("", []), + ("call f1", [(0, 0), (0, 0)]), + ("call f2", [(1, 0)]), + ("return", [(0, 1), (2, 0)]), + ] + operation_list = extract_operation_list(mlir_file_path=os.path.join( + os.path.dirname(__file__), "entry.mlir"), + function_name="caller") + assert_nested_array_equals(expected_operation_list, operation_list) + + def test_mlir_execution(self): + mlir_path_function_pairs = [ + (os.path.join(os.path.dirname(__file__), "entry.mlir"), "caller"), + (os.path.join(os.path.dirname(__file__), "f1.mlir"), "f1"), + (os.path.join(os.path.dirname(__file__), "f2.mlir"), "main"), + ] + compile_kwargs = { + "target_backends": ["llvm-cpu"], + } + device = iree.runtime.get_device("local-task") + input = [np.array([1], dtype=np.float32), np.array([2], dtype=np.float32)] + results = execute_mlir_with_iree( + input=input, + mlir_path_function_pairs=mlir_path_function_pairs, + compile_kwargs=compile_kwargs, + device=device) + expected_output = [ + np.array([2], dtype=np.float32), + np.array([4], dtype=np.float32) + ] + assert_nested_array_equals(results[-1], expected_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir new file mode 100644 index 000000000000..8ca9609123a7 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir @@ -0,0 +1,10 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func @f1(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { + %0 = arith.addf %arg0, %arg1: tensor<1xf32> + return %0, %0 : tensor<1xf32>, tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir new file mode 100644 index 000000000000..6b0bc790fdbb --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir @@ -0,0 +1,10 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> { + %0 = arith.addf %arg0, %arg0 : tensor<1xf32> + return %0 : tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/types.py b/experimental/split_mlir/src/iree/split_mlir/types.py new file mode 100644 index 000000000000..418bc207ab91 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/types.py @@ -0,0 +1,19 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, TypeVar, Callable, Tuple +from numbers import Integral + +Tensor = TypeVar("Tensor") +OpId = TypeVar("OpId") +ExecuteOp = Callable[[OpId, List[Tensor]], List[Tensor]] +OperationIndex = Integral +ResultIndex = Integral +"""Description of the dependencies of an operation.""" +OpArguments = List[Tuple[OperationIndex, ResultIndex]] +Operation = Tuple[OpId, OpArguments] +"""Describes a dependency graph of operations.""" +OperationList = List[Operation] From 3c08f6ba9f2d4529a83ca526f804818bf4b40973 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Wed, 22 Mar 2023 19:18:50 -0400 Subject: [PATCH 10/38] [Preprocessing] Add pass to generalize 2d convolutions This can be useful when trying to do layout propagation and guaranteeing specific fusion at time (use with caution). --- .../compiler/Preprocessing/Common/BUILD.bazel | 1 + .../Preprocessing/Common/CMakeLists.txt | 1 + .../Common/GeneralizeConvolutions.cpp | 68 +++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.h | 3 + .../compiler/Preprocessing/Common/Passes.td | 6 ++ 5 files changed, 79 insertions(+) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index f8582d347074..3490034349f1 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -34,6 +34,7 @@ iree_compiler_cc_library( "ConvertConvNchwToNhwc.cpp", "ConvertLinalgMatmulToMmt.cpp", "GeneralizeAndFuse.cpp", + "GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", "PassDetail.h", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index b5bd22be2d72..1ef168c44ce3 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -30,6 +30,7 @@ iree_cc_library( "ConvertConvNchwToNhwc.cpp" "ConvertLinalgMatmulToMmt.cpp" "GeneralizeAndFuse.cpp" + "GeneralizeConvolutions.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" "PassDetail.h" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp new file mode 100644 index 000000000000..ada980a788ed --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp @@ -0,0 +1,68 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +namespace { + +template +class GeneralizeTargetNamedOp final : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LinalgOpType linalgOp, + PatternRewriter &rewriter) const override { + FailureOr genericOp = + linalg::generalizeNamedOp(rewriter, linalgOp); + if (failed(genericOp)) return failure(); + return success(); + } +}; + +struct GeneralizeConvolutionsPass + : GeneralizeConvolutionsBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(&getContext()); + patterns.insert>(context); + patterns.insert>(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr createGeneralizeConvolutionsPass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index 4831ff8d1ce0..4a03127604ea 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -34,6 +34,9 @@ std::unique_ptr createGeneralizeAndFusePass(); std::unique_ptr> createMakeSingleDispatchForFunctionPass(); +// A pass to generalize all conv-like ops. +std::unique_ptr createGeneralizeConvolutionsPass(); + // A pass to pad linalg ops to the next integer multiple of `paddingSize`. std::unique_ptr createPadLinalgOpsToIntegerMultiplePass(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index ca1c9913d21e..6e83b0a582e6 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -40,6 +40,12 @@ def MakeSingleDispatchForFunction : let constructor = "mlir::iree_compiler::Preprocessing::createMakeSingleDispatchForFunctionPass()"; } +def GeneralizeConvolutions : + Pass<"iree-preprocessing-generalize-convolutions", ""> { + let summary = "Generalize all convolution ops"; + let constructor = "mlir::iree_compiler::IREE::createGeneralizeConvolutionsPass()"; +} + def PadLinalgOps : Pass<"iree-preprocessing-pad-linalg-ops", ""> { let summary = "Pad linalg ops to the next integer multiple of paddingSize."; From 8ee95663508e786ed19dbac47008fad3e9ee43c3 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 16 Feb 2023 14:57:11 -0500 Subject: [PATCH 11/38] [Preprocessing] Add a pass to convert convolutions to channels last This pass is the spiritual successor to `convert-conv-nchw-to-nhwc` focused on generalizing to enable data tiling and more robust layout propagation, as well as supporting non-named convolutions as well. Currently this includes some baked in generalization patterns and does not support padding. Tile size selection currently is pass-wide, but there is limited attribute control to enable fully transposing. Further generalizations should aim to write this pass by allowing per-op tile size control. --- .../compiler/Preprocessing/Common/BUILD.bazel | 3 +- .../Preprocessing/Common/CMakeLists.txt | 1 + .../Common/ConvertConvToChannelsLast.cpp | 890 ++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.h | 3 + .../compiler/Preprocessing/Common/Passes.td | 12 + 5 files changed, 908 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 3490034349f1..08795e578b27 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -32,8 +32,9 @@ iree_compiler_cc_library( srcs = [ "ConvertConv2DToImg2Col.cpp", "ConvertConvNchwToNhwc.cpp", + "ConvertConvToChannelsLast.cpp", "ConvertLinalgMatmulToMmt.cpp", - "GeneralizeAndFuse.cpp", + "GeneralizeAndFuse.cpp", "GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index 1ef168c44ce3..1b83bc1ea1d7 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -28,6 +28,7 @@ iree_cc_library( SRCS "ConvertConv2DToImg2Col.cpp" "ConvertConvNchwToNhwc.cpp" + "ConvertConvToChannelsLast.cpp" "ConvertLinalgMatmulToMmt.cpp" "GeneralizeAndFuse.cpp" "GeneralizeConvolutions.cpp" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp new file mode 100644 index 000000000000..4a8b833bfb1a --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp @@ -0,0 +1,890 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-preprocessing-convert-conv-to-channels-last" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +static const StringLiteral fullTileTransposeMarker = "__fully_transpose_tile__"; + +using TransposeIndices = SmallVector; +using ConvBuilderFn = std::function newDimOrder, + SmallVector newIteratorTypes)>; +using linalg::detail::MatchConvolutionResult; + +static Value defaultConvBuilderFn( + OpBuilder &b, Location loc, linalg::LinalgOp srcConv, Value input, + Value filter, Value output, AffineMap inputMap, AffineMap filterMap, + AffineMap outputMap, SmallVector newDimOrder, + SmallVector newIteratorTypes) { + AffineMap newInputMap = inputMap; + AffineMap newFilterMap = filterMap; + AffineMap newOutputMap = outputMap; + if (!newDimOrder.empty()) { + DenseMap dimMap; + for (auto [newDim, oldDim] : llvm::enumerate(newDimOrder)) + dimMap[b.getAffineDimExpr(oldDim)] = b.getAffineDimExpr(newDim); + newInputMap = inputMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + /*numResultSymbols=*/0); + newFilterMap = filterMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + /*numResultSymbols=*/0); + newOutputMap = outputMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + /*numResultSymbols=*/0); + } + SmallVector iterators = srcConv.getIteratorTypesArray(); + iterators.append(newIteratorTypes); + auto genericConv = b.create( + loc, output.getType(), ValueRange{input, filter}, output, + ArrayRef{newInputMap, newFilterMap, newOutputMap}, iterators); + IRMapping mapper; + srcConv->getRegion(0).cloneInto(&genericConv.getRegion(), mapper); + return genericConv.getResult(0); +} + +template +static Value namedConvBuilderFn( + OpBuilder &b, Location loc, linalg::LinalgOp srcConv, Value input, + Value filter, Value output, AffineMap inputMap, AffineMap filterMap, + AffineMap outputMap, SmallVector newDimOrder, + SmallVector newIteratorTypes) { + sourceNamedConvTy namedConv = cast(srcConv); + return b + .create( + loc, output.getType(), ValueRange{input, filter}, output, + namedConv.getStrides(), namedConv.getDilations()) + .getResult(0); +} + +static TransposeIndices getNormalizedIndices(TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + TransposeIndices normalized(targetIndices.size()); + for (auto i : llvm::enumerate(targetIndices)) + normalized[i.index()] = i.value() - startDim; + return normalized; +} + +static TransposeIndices invertIndices(TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + TransposeIndices inverted(targetIndices.size()); + for (auto i : llvm::enumerate(targetIndices)) { + inverted[i.value() - startDim] = i.index() + startDim; + } + return inverted; +} + +static bool isInnerIdentityIndices(TransposeIndices indices, int64_t rank) { + return indices.empty() || + (llvm::all_of(llvm::enumerate(indices), + [indices](auto e) { + if (e.index() == 0) return true; + return indices[e.index() - 1] < e.value(); + }) && + indices.back() == rank - 1); +} + +// Helper to shuffle vectors according to the transpose indices. +template +static SmallVector shuffleFromIndices(SmallVector unshuffled, + TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + SmallVector shuffled(unshuffled); + for (auto i : llvm::enumerate(targetIndices)) { + shuffled[i.index() + startDim] = unshuffled[i.value()]; + } + return shuffled; +} + +template +static SmallVector getPackedVector(SmallVector vec, + TransposeIndices targetIndices) { + SmallVector packedShape; + for (auto [i, val] : llvm::enumerate(vec)) + if (!llvm::is_contained(targetIndices, i)) packedShape.push_back(val); + for (auto i : targetIndices) packedShape.push_back(vec[i]); + return packedShape; +} + +static SmallVector getUntiledPackReassociationMap( + TransposeIndices targetIndices, int64_t rank) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + int dimCount = targetIndices.size(); + SmallVector reassociationMap; + for (int i = 0; i <= startDim; i++) reassociationMap.push_back({i}); + for (int i = startDim + 1; i < dimCount + startDim + 1; i++) + reassociationMap[startDim].push_back(i); + for (int i = dimCount + startDim + 1; i < dimCount + rank; i++) + reassociationMap.push_back({i}); + return reassociationMap; +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. +static std::tuple, AffineMap> +createTransposeAsTensorPack( + PatternRewriter &rewriter, Location loc, Value input, AffineMap inputMap, + TransposeIndices targetIndices, int tilingFactor, + llvm::DenseMap innerDimToDomainDim) { + if (isInnerIdentityIndices(targetIndices, inputMap.getNumResults())) + return std::make_tuple(input, std::nullopt, inputMap); + + RankedTensorType inType = input.getType().cast(); + auto elementType = inType.getElementType(); + auto inputShape(inType.getShape()); + + SmallVector transposedTileSizes( + targetIndices.size(), rewriter.getIndexAttr(tilingFactor)); + if (tilingFactor <= 0) { + for (auto [index, i] : llvm::enumerate(targetIndices)) { + if (ShapedType::isDynamic(inputShape[i])) + transposedTileSizes[index] = + rewriter.create(loc, input, i).getResult(); + else + transposedTileSizes[index] = rewriter.getIndexAttr(inputShape[i]); + } + } + + // Pack the input tensor. + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, input, transposedTileSizes, targetIndices, + SmallVector{}); + auto packedInput = rewriter.create( + loc, input, empty, targetIndices, transposedTileSizes, + /*padding=*/std::nullopt, SmallVector{}); + + SmallVector mapResults(inputMap.getResults()); + AffineMap transposedMap; + + Value packedOperand = packedInput; + // Collapse the unit dims created by tensor.pack. + if (tilingFactor <= 0) { + auto reassociationMap = + getUntiledPackReassociationMap(targetIndices, inType.getRank()); + auto transposedInputShape = + getPackedVector(llvm::to_vector(inputShape), targetIndices); + packedOperand = + rewriter + .create( + loc, RankedTensorType::get(transposedInputShape, elementType), + packedOperand, reassociationMap) + .getResult(); + transposedMap = + AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(), + getPackedVector(mapResults, targetIndices), + input.getContext()); + } else { + for (auto innerDim : targetIndices) { + mapResults.push_back(rewriter.getAffineDimExpr( + innerDimToDomainDim[inputMap.getDimPosition(innerDim)])); + } + transposedMap = AffineMap::get( + inputMap.getNumDims() + innerDimToDomainDim.size(), + inputMap.getNumSymbols(), mapResults, input.getContext()); + } + + return std::make_tuple(packedOperand, packedInput, transposedMap); +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. +static Value createTransposeAsTensorUnPack(PatternRewriter &rewriter, + Location loc, Value output, + tensor::PackOp packOp, + int tilingFactor) { + Value packedOutput = output; + if (tilingFactor <= 0) { + RankedTensorType outType = output.getType().cast(); + auto elementType = outType.getElementType(); + auto outputShape(outType.getShape()); + int64_t rank = outType.getRank(); + TransposeIndices targetIndices(packOp.getInnerDimsPos()); + + int startDim = + *std::min_element(targetIndices.begin(), targetIndices.end()); + SmallVector expandedOutputShape; + for (int i = 0, e = startDim; i < e; i++) + expandedOutputShape.push_back(outputShape[i]); + for (int i = 0, e = targetIndices.size(); i < e; i++) + expandedOutputShape.push_back(1); + for (int i = startDim, e = rank; i < e; i++) + expandedOutputShape.push_back(outputShape[i]); + + auto reassociationMap = getUntiledPackReassociationMap(targetIndices, rank); + packedOutput = + rewriter + .create( + loc, RankedTensorType::get(expandedOutputShape, elementType), + output, reassociationMap) + .getResult(); + } + + Value empty = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packedOutput, packOp.getMixedTiles(), + packOp.getInnerDimsPos(), packOp.getOuterDimsPerm()); + + auto unpackedOutput = rewriter.create( + loc, packedOutput, empty, packOp.getInnerDimsPos(), + packOp.getMixedTiles(), packOp.getOuterDimsPerm()); + unpackedOutput->setAttr("__unpack__", rewriter.getUnitAttr()); + return unpackedOutput.getResult(); +} + +static TransposeIndices collectChannelTransposeIndices( + AffineMap map, SmallVector> transposeDimTargets) { + SmallVector channelIndices(transposeDimTargets.size()); + for (auto [index, result] : llvm::enumerate(map.getResults())) { + if (result.isa()) { + for (auto [channelVec, dimCategory] : + llvm::zip_equal(channelIndices, transposeDimTargets)) { + if (llvm::is_contained(dimCategory, + result.cast().getPosition())) { + channelVec.push_back(index); + break; + } + } + } + } + + TransposeIndices indices; + for (auto channelVec : channelIndices) indices.append(channelVec); + return indices; +} + +static LogicalResult transposeConvLikeLinalgOp( + PatternRewriter &rewriter, linalg::LinalgOp convOp, int tilingFactor, + ConvBuilderFn convBuilder = defaultConvBuilderFn) { + Location loc = convOp.getLoc(); + + linalg::ConvolutionDimensions convDims; + auto errString = getMatchConvolutionMessage( + linalg::detail::isConvolutionInterfaceImpl(convOp, &convDims)); + if (!errString.empty()) return failure(); + + if (convDims.inputChannel.size() > 1) return failure(); + + if (convDims.outputChannel.size() > 1) return failure(); + + // TODO: Support depthwise convolutions + if (!convDims.depth.empty()) return failure(); + + Value input = convOp->getOperand(0); + Value filter = convOp->getOperand(1); + Value output = convOp->getOperand(2); + + auto inputMap = convOp.getIndexingMapsArray()[0]; + auto filterMap = convOp.getIndexingMapsArray()[1]; + auto outputMap = convOp.getIndexingMapsArray()[2]; + + auto inputIndices = + collectChannelTransposeIndices(inputMap, {convDims.inputChannel}); + auto filterIndices = collectChannelTransposeIndices( + filterMap, {convDims.inputChannel, convDims.outputChannel}); + auto outputIndices = + collectChannelTransposeIndices(outputMap, {convDims.outputChannel}); + + // Don't transpose if there's no change to the op. + if (isInnerIdentityIndices(inputIndices, inputMap.getNumResults()) && + isInnerIdentityIndices(filterIndices, filterMap.getNumResults()) && + isInnerIdentityIndices(outputIndices, outputMap.getNumResults())) + return failure(); + + int nDims = outputMap.getNumDims(); + llvm::DenseMap innerDimsToDomainDims; + for (auto [index, dim] : llvm::enumerate(convDims.inputChannel)) { + innerDimsToDomainDims[dim] = nDims + index; + } + for (auto [index, dim] : llvm::enumerate(convDims.outputChannel)) { + innerDimsToDomainDims[dim] = nDims + index + convDims.inputChannel.size(); + } + + auto [transposedInput, inputPack, transposedInputMap] = + createTransposeAsTensorPack(rewriter, loc, input, inputMap, inputIndices, + tilingFactor, innerDimsToDomainDims); + auto [transposedFilter, filterPack, transposedFilterMap] = + createTransposeAsTensorPack(rewriter, loc, filter, filterMap, + filterIndices, tilingFactor, + innerDimsToDomainDims); + auto [transposedOutput, outputPack, transposedOutputMap] = + createTransposeAsTensorPack(rewriter, loc, output, outputMap, + outputIndices, tilingFactor, + innerDimsToDomainDims); + + // Don't transpose if there's no change to the op. + if (transposedInputMap == inputMap && transposedFilterMap == filterMap && + transposedOutputMap == outputMap) + return failure(); + + Value convDest = transposedOutput; + if (auto fillOp = output.getDefiningOp()) { + if (outputPack) { + auto outputDest = outputPack->getDest().getDefiningOp(); + auto elementType = outputDest.getType().getElementType(); + + auto dimToTileMapping = outputPack->getDimAndTileMapping(); + SmallVector mixedSizes = outputDest.getMixedSizes(); + SmallVector packedSizes; + for (auto [index, size] : llvm::enumerate(mixedSizes)) + if (!dimToTileMapping.count(index) || tilingFactor > 0) + packedSizes.push_back(size); + + auto emptyOp = + rewriter.create(loc, packedSizes, elementType); + + convDest = rewriter + .create(loc, fillOp.getInputs(), + emptyOp.getResult()) + .result(); + } + } + + SmallVector newDimOrder; + SmallVector newIteratorTypes; + if (tilingFactor <= 0) { + newDimOrder.append(convDims.batch); + newDimOrder.append(convDims.outputImage); + newDimOrder.append(convDims.outputChannel); + newDimOrder.append(convDims.filterLoop); + newDimOrder.append(convDims.inputChannel); + } else { + newIteratorTypes.append(convDims.inputChannel.size(), + utils::IteratorType::reduction); + newIteratorTypes.append(convDims.outputChannel.size(), + utils::IteratorType::parallel); + } + + Value transposedConvResult = + convBuilder(rewriter, loc, convOp, transposedInput, transposedFilter, + convDest, transposedInputMap, transposedFilterMap, + transposedOutputMap, newDimOrder, newIteratorTypes); + + Value returnToNCHW = transposedConvResult; + if (outputPack) { + returnToNCHW = createTransposeAsTensorUnPack( + rewriter, loc, transposedConvResult, *outputPack, tilingFactor); + } + + rewriter.replaceOp(convOp, returnToNCHW); + return success(); +} + +namespace { + +//===================================================================== +// Convolution packing patterns +//===================================================================== + +struct ConvertLinalgConvNchwFchw : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgConvNchwFchw(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, convOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgPoolingNchwMax + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgPoolingNchwMax(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::PoolingNchwMaxOp poolOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, poolOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgPoolingNchwSum + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgPoolingNchwSum(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::PoolingNchwSumOp poolOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, poolOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgConvOp : OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + ConvertLinalgConvOp(MLIRContext *context, int tile, + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + tilingFactor(tile) {} + + LogicalResult matchAndRewrite(linalg::LinalgOp op, + PatternRewriter &rewriter) const override { + if (op->hasAttr(fullTileTransposeMarker)) + return transposeConvLikeLinalgOp(rewriter, op, 0); + return transposeConvLikeLinalgOp(rewriter, op, tilingFactor); + } + + private: + int tilingFactor; +}; + +//===================================================================== +// Propagation patterns +//===================================================================== + +class BubbleUpPackThroughPadOp final : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + auto padOp = packOp.getSource().getDefiningOp(); + if (!padOp) return failure(); + + if (!padOp.getResult().hasOneUse()) return failure(); + + // TODO: Enable padding. + if (packOp.getPaddingValue()) return failure(); + + // TODO: Enable outer dims perm. + if (!packOp.getOuterDimsPerm().empty()) return failure(); + + // We want to move the pack not the insert_slice. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(padOp); + + Location loc = padOp->getLoc(); + auto mixedTiles = packOp.getMixedTiles(); + auto innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + if (!packOp.getDest().getDefiningOp()) return failure(); + + // Bail out if one of the padded dimension is a tiled one. + llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); + llvm::SmallBitVector innerDims(paddedDims.size()); + for (int64_t dim : innerDimsPos) innerDims.flip(dim); + if (paddedDims.anyCommon(innerDims)) return failure(); + + Value paddingVal = padOp.getConstantPaddingValue(); + if (!paddingVal) return failure(); + + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, + outerDimsPerm); + Value packedSource = rewriter.create( + loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, + /*padding=*/std::nullopt, outerDimsPerm); + + // If we have `outer_dims_perms` we need to adjust the padded dimensions. + SmallVector lowPad = padOp.getMixedLowPad(); + SmallVector highPad = padOp.getMixedHighPad(); + if (!outerDimsPerm.empty()) { + applyPermutationToVector(lowPad, outerDimsPerm); + applyPermutationToVector(highPad, outerDimsPerm); + } + // Add zero padding for the point loops. + size_t pointLoopsSize = innerDimsPos.size(); + lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); + highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); + + auto newPadOp = rewriter.create( + loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal, + padOp.getNofold()); + rewriter.replaceOp(packOp, newPadOp.getResult()); + return success(); + } +}; + +class BubbleUpPackThroughTensorInsertSlice final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + auto insertSliceOp = + packOp.getSource().getDefiningOp(); + if (!insertSliceOp) return failure(); + + if (!insertSliceOp.getResult().hasOneUse()) return failure(); + + // TODO: Enable rank reduced slice. + if (insertSliceOp.getSourceType().getRank() != + insertSliceOp.getDestType().getRank()) + return failure(); + + // TODO: Enable padding. + if (packOp.getPaddingValue()) return failure(); + + // TODO: Enable outer dims perm. + if (!packOp.getOuterDimsPerm().empty()) return failure(); + + // We want to move the pack not the insert_slice. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(insertSliceOp); + + Location loc = insertSliceOp->getLoc(); + auto mixedTiles = packOp.getMixedTiles(); + auto innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + Value packOpDest = packOp.getDest(); + if (!packOpDest.hasOneUse()) return failure(); + if (auto emptyOp = packOpDest.getDefiningOp()) { + packOpDest = tensor::PackOp::createDestinationTensor( + rewriter, loc, insertSliceOp.getDest(), mixedTiles, innerDimsPos, + outerDimsPerm); + } else { + DominanceInfo dom(insertSliceOp); + if (!dom.properlyDominates(packOpDest, insertSliceOp)) return failure(); + } + + SmallVector mixedSliceTiles(packOp.getMixedTiles()); + + SmallVector mixedOffsets(insertSliceOp.getMixedOffsets()); + SmallVector mixedSizes(insertSliceOp.getMixedSizes()); + SmallVector mixedStrides(insertSliceOp.getMixedStrides()); + + for (auto [index, dimPos, mixedTileSize] : + llvm::zip_equal(llvm::seq(0, innerDimsPos.size()), + innerDimsPos, mixedTiles)) { + if (!getConstantIntValue(mixedStrides[dimPos])) return failure(); + + std::optional constTileSize = getConstantIntValue(mixedTileSize); + if (!constTileSize) return failure(); + + std::optional constOffset = + getConstantIntValue(mixedOffsets[dimPos]); + if (!constOffset) return failure(); + + std::optional constSize = + getConstantIntValue(mixedSizes[dimPos]); + if (!constOffset) return failure(); + + int64_t tileSize = *constTileSize; + int64_t offset = *constOffset; + int64_t size = *constSize; + + if ((size % tileSize != 0 || offset % tileSize != 0) && + (offset / tileSize > (size + offset) / tileSize)) + return failure(); + mixedSliceTiles[index] = + rewriter.getI64IntegerAttr(std::min(size, tileSize)); + mixedOffsets[dimPos] = rewriter.getI64IntegerAttr(offset / tileSize); + mixedSizes[dimPos] = + rewriter.getI64IntegerAttr(std::max(size / tileSize, 1)); + + mixedOffsets.push_back(rewriter.getI64IntegerAttr(offset % tileSize)); + mixedSizes.push_back( + rewriter.getI64IntegerAttr(std::min(size, tileSize))); + mixedStrides.push_back(rewriter.getI64IntegerAttr(1)); + } + + Value newDest = packOpDest; + if (!insertSliceOp.getDest().getDefiningOp()) { + newDest = rewriter.create( + loc, insertSliceOp.getDest(), packOpDest, innerDimsPos, mixedTiles, + /*padding=*/std::nullopt, outerDimsPerm); + } + + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, insertSliceOp.getSource(), mixedSliceTiles, innerDimsPos, + outerDimsPerm); + Value packedSlice = rewriter.create( + loc, insertSliceOp.getSource(), empty, innerDimsPos, mixedSliceTiles, + /*padding=*/std::nullopt, outerDimsPerm); + + rewriter.replaceOpWithNewOp( + packOp, packedSlice, newDest, mixedOffsets, mixedSizes, mixedStrides); + return success(); + } +}; + +//===================================================================== +// Generalization and folding patterns +//===================================================================== + +template +class GeneralizeUntiledPackOrUnPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PackOrUnPackOpTy op, + PatternRewriter &rewriter) const override { + if (!op.getMixedTiles().empty()) return failure(); + TransposeIndices perm(op.getOuterDimsPerm()); + if (std::is_same::value) + perm = invertIndices(perm); + rewriter.replaceOpWithNewOp(op, op.getSource(), + op.getDest(), perm); + return success(); + } +}; + +static SmallVector getTilingReassociationMap( + int64_t rank, llvm::DenseMap innerDims) { + SmallVector map; + int64_t nTiled = 0; + for (int64_t i = 0, e = rank; i < e; i++) { + if (innerDims.count(i)) { + map.push_back({i + nTiled++, i + nTiled}); + continue; + } + map.push_back({i + nTiled}); + } + return map; +} + +class GeneralizeUnPermutedPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + if (!packOp.getOuterDimsPerm().empty()) return failure(); + if (packOp.getPaddingValue()) return failure(); + + RankedTensorType srcType = + packOp.getSource().getType().cast(); + int64_t rank = srcType.getRank(); + auto innerDimsPos = packOp.getInnerDimsPos(); + llvm::DenseMap innerDims; + for (auto [index, innerDim] : llvm::enumerate(innerDimsPos)) + innerDims[innerDim] = index; + + llvm::DenseMap innerDimsToExpandedDims; + TransposeIndices perm; + int64_t nTiled = 0; + for (int i = 0, e = rank; i < e; i++) { + perm.push_back(i + nTiled); + if (innerDims.count(i)) innerDimsToExpandedDims[i] = i + ++nTiled; + } + for (auto i : innerDimsPos) perm.push_back(innerDimsToExpandedDims[i]); + + RankedTensorType destType = + packOp.getDest().getType().cast(); + SmallVector destShape(destType.getShape()); + applyPermutationToVector(destShape, invertPermutationVector(perm)); + + auto expand = rewriter.create( + packOp.getLoc(), + RankedTensorType::get(destShape, destType.getElementType()), + packOp.getSource(), getTilingReassociationMap(rank, innerDims)); + + rewriter.replaceOpWithNewOp(packOp, expand, + packOp.getDest(), perm); + return success(); + } +}; + +class GeneralizeUnPermutedUnPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, + PatternRewriter &rewriter) const override { + if (!unpackOp.getOuterDimsPerm().empty()) return failure(); + + if (!unpackOp.getDest().getDefiningOp()) return failure(); + + RankedTensorType destType = + unpackOp.getDest().getType().cast(); + int64_t rank = destType.getRank(); + auto innerDimsPos = unpackOp.getInnerDimsPos(); + llvm::DenseMap innerDims; + for (auto [index, innerDim] : llvm::enumerate(innerDimsPos)) + innerDims[innerDim] = index; + + TransposeIndices perm; + for (int i = 0, e = rank; i < e; i++) { + perm.push_back(i); + if (innerDims.count(i)) perm.push_back(rank + innerDims[i]); + } + + Location loc = unpackOp.getLoc(); + SmallVector mixedSizes = + tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()); + applyPermutationToVector(mixedSizes, perm); + auto elType = getElementTypeOrSelf(unpackOp.getDest()); + + auto emptyOp = rewriter.create(loc, mixedSizes, elType); + + Value transpose = rewriter + .create( + loc, unpackOp.getSource(), emptyOp, perm) + ->getResult(0); + + rewriter.replaceOpWithNewOp( + unpackOp, destType, transpose, + getTilingReassociationMap(rank, innerDims)); + return success(); + } +}; + +class GeneralizeLinalgTransposeOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp op, + PatternRewriter &rewriter) const override { + auto linalgOp = cast(*op); + auto transpose = + rewriter + .create( + op.getLoc(), op.getResult().getType(), op.getInput(), + op.getInit(), linalgOp.getIndexingMapsArray(), + linalgOp.getIteratorTypesArray(), + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + rewriter.replaceOp(op, transpose); + return success(); + } +}; + +class FoldCancellingUnPackPackOps final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, + PatternRewriter &rewriter) const override { + return tensor::UnPackOp::canonicalize(unpackOp, rewriter); + } +}; + +class FoldCancellingPackUnPackOps final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + return tensor::PackOp::canonicalize(packOp, rewriter); + } +}; + +struct ConvertConvToChannelsLastPass + : public ConvertConvToChannelsLastBase { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + LogicalResult initializeOptions(StringRef options) override { + if (failed(Pass::initializeOptions(options))) { + return failure(); + } + tilingFactor = tileSize; + return success(); + } + + void runOnOperation() override { + auto op = getOperation(); + MLIRContext *context = &getContext(); + + { + RewritePatternSet patterns(context); + if (tilingFactor < 0) { + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + } + patterns.insert(context, tilingFactor); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + linalg::populateDataLayoutPropagationPatterns( + patterns, [](Operation *op) { return true; }); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.add(context); + patterns.insert(context); + patterns.insert>(context); + patterns.insert>( + context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + } + + private: + int64_t tilingFactor; +}; + +} // namespace + +std::unique_ptr createConvertConvToChannelsLastPass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index 4a03127604ea..662bdf59d9cf 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -37,6 +37,9 @@ createMakeSingleDispatchForFunctionPass(); // A pass to generalize all conv-like ops. std::unique_ptr createGeneralizeConvolutionsPass(); +// Creates a pass to convert convolutions to channels last and propagate. +std::unique_ptr createConvertConvToChannelsLastPass(); + // A pass to pad linalg ops to the next integer multiple of `paddingSize`. std::unique_ptr createPadLinalgOpsToIntegerMultiplePass(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index 6e83b0a582e6..db0d89940ef7 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -46,6 +46,18 @@ def GeneralizeConvolutions : let constructor = "mlir::iree_compiler::IREE::createGeneralizeConvolutionsPass()"; } +def ConvertConvToChannelsLast : + Pass<"iree-preprocessing-convert-conv-to-channels-last", ""> { + let summary = "Convert linalg convolutions to channels last."; + let constructor = + "mlir::iree_compiler::IREE::createConvertConvToChannelsLastPass()"; + let options = [ + Option<"tileSize", "tile-size", "int", + /*default=*/"0", + "Specify the tiling factor">, + ]; +} + def PadLinalgOps : Pass<"iree-preprocessing-pad-linalg-ops", ""> { let summary = "Pad linalg ops to the next integer multiple of paddingSize."; From 8d6a76691fee69a17bd8d0335cb05708be64c922 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Wed, 10 May 2023 04:46:08 -0400 Subject: [PATCH 12/38] Fix embedded linker flag in python exposed iree-opt --- .../python/iree/compiler/tools/core.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/compiler/bindings/python/iree/compiler/tools/core.py b/compiler/bindings/python/iree/compiler/tools/core.py index cd9829ff0451..e019d7d767fe 100644 --- a/compiler/bindings/python/iree/compiler/tools/core.py +++ b/compiler/bindings/python/iree/compiler/tools/core.py @@ -325,6 +325,43 @@ def query_available_targets(): # Preprocessing for SHARK (for now simply exposes iree-opt) +def build_opt_command_line( + input_file: str, tfs: TempFileSaver, options: CompilerOptions +) -> List[str]: + """Builds a command line for applying specified patterns. + + Args: + input_file: The input file name. + tfs: TempFileSaver. + options: Compiler options. + Returns: + List of strings of command line. + """ + iree_opt = find_tool("iree-opt") + cl = [ + iree_opt, + input_file, + ] + + # Output file. + if options.output_file: + cl.append(f"-o={options.output_file}") + + # Tool paths. + lld_path = find_tool("iree-lld") + cl.append(f"--iree-llvmcpu-embedded-linker-path={lld_path}") + + crash_reproducer_path = tfs.alloc_optional( + "core-reproducer.mlir", export_as=options.crash_reproducer_path + ) + if crash_reproducer_path: + cl.append(f"--mlir-pass-pipeline-crash-reproducer={crash_reproducer_path}") + + cl.extend(options.extra_args) + print(cl) + return cl + + def build_opt_command_line( input_file: str, tfs: TempFileSaver, options: CompilerOptions ) -> List[str]: From 8dbd8b008a704e50d7b9e5b46f443cbef3c29a63 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 20 Jul 2023 11:51:35 -0400 Subject: [PATCH 13/38] Add pattern for bubbling vector.bitcast through an enclosing scf.if <32 bit width types are handled on the SPIR-V side by introducing bitcasts to and from i32 and bubbling them to the center of the kernel hoping to cancel. This adds a pattern for a bitcast on the result of an scf.if, which comes from the way that padding is handled (transfer_read in the `then` branch, else yield a splat constant). --- .../Common/OptimizeVectorTransferPass.cpp | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp index d9111c7f9aa7..097bb2b0dfbc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp @@ -51,6 +51,84 @@ class TransposeUnitDimToShapeCast } }; +// TODO: Move this upstream +// Hoists a vector.bitcast op to the output of the enclosing scf.if +// +// This transforms IR like: +// %0 = scf.if %1 -> (vector<16xi8>) { +// %2 = memref.load %4[%c0] : memref> +// %3 = vector.bitcast %2 : vector<4xi32> to vector<16xi8> +// scf.yield %3 : vector<16xi8> +// } else { +// scf.yield %cst : vector<16xi8> +// } +// Into: +// %0 = scf.if %1 -> (vector<4xi32>) { +// %2 = memref.load %4[%c0] : memref> +// scf.yield %2 : vector<4xi32> +// } else { +// %3 = vector.bitcast %cst : vector<16xi8> to vector<4xi32> +// scf.yield %0 : vector<4xi32> +// } +// %3 = vector.bitcast %0 : vector<4xi32> to vector<16xi8> +struct BubbleUpBitCastOfScfIf : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::IfOp ifOp, + PatternRewriter &rewriter) const override { + // Bail on more than one result for now. + scf::YieldOp thenYield = ifOp.thenYield(); + if (!thenYield || thenYield.getNumOperands() != 1) + return failure(); + auto bitcastOp = thenYield.getOperand(0).getDefiningOp(); + // Bail out if no bitcast on the if then statement. + if (!bitcastOp) + return failure(); + + VectorType castSrcType = bitcastOp.getSourceVectorType(); + VectorType castDstType = bitcastOp.getResultVectorType(); + assert(castSrcType.getRank() == castDstType.getRank()); + // Skip 0-D vector. + if (castSrcType.getRank() == 0) + return failure(); + + int64_t castSrcLastDim = castSrcType.getShape().back(); + int64_t castDstLastDim = castDstType.getShape().back(); + // Require casting to more elements; + if (castSrcLastDim > castDstLastDim) + return failure(); + + Location loc = ifOp.getLoc(); + + auto bitcastedIfOp = + rewriter.create(loc, castSrcType, ifOp.getCondition()); + bitcastedIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + bitcastedIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + + scf::YieldOp newThenYield = bitcastedIfOp.thenYield(); + auto newBitcastOp = + newThenYield.getOperand(0).getDefiningOp(); + + newThenYield.setOperand(0, newBitcastOp.getSource()); + + auto newBitcast = rewriter.create( + loc, castDstType, bitcastedIfOp.getResult(0)); + + scf::YieldOp elseYield = bitcastedIfOp.elseYield(); + if (elseYield) { + OpBuilder::InsertionGuard elseGuard(rewriter); + rewriter.setInsertionPoint(elseYield); + + Value yieldSrc = elseYield.getOperand(0); + auto elseBitcast = + rewriter.create(loc, castSrcType, yieldSrc); + elseYield.setOperand(0, elseBitcast); + } + rewriter.replaceOp(ifOp, newBitcast); + return success(); + } +}; + static void loopInvariantCodeMotion(func::FuncOp funcOp) { // Walk through all loops in a function in innermost-loop-first order. This // way, we first LICM from the inner loop, and place the ops in @@ -99,6 +177,7 @@ struct OptimizeVectorTransferPass { RewritePatternSet patterns(&getContext()); vector::populateBubbleVectorBitCastOpPatterns(patterns); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } From 8d3ac7486ba73331a76489eb0af8b7018226b54f Mon Sep 17 00:00:00 2001 From: Anush Elangovan Date: Fri, 5 Aug 2022 06:39:45 +0000 Subject: [PATCH 14/38] [CI] Add ROCM builds to the the nightly Build Experimental ROCM builds --- compiler/setup.py | 2 ++ runtime/setup.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/compiler/setup.py b/compiler/setup.py index 01a347d296b9..4ff0ddf68f04 100644 --- a/compiler/setup.py +++ b/compiler/setup.py @@ -259,6 +259,8 @@ def prepare_installation(): "-DPython3_EXECUTABLE={}".format(sys.executable), "-DCMAKE_BUILD_TYPE={}".format(cfg), # TODO(scotttodd): include IREE_TARGET_BACKEND_WEBGPU here (and in env) + get_env_cmake_option("IREE_TARGET_BACKEND_ROCM"), + get_env_cmake_option("IREE_TARGET_BACKEND_OPENCL_SPIRV"), get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), get_env_cmake_option("IREE_TARGET_BACKEND_ROCM", "ON"), get_env_cmake_option("IREE_ENABLE_LLD", "OFF"), diff --git a/runtime/setup.py b/runtime/setup.py index e561c45ea6e5..345f31fc9291 100644 --- a/runtime/setup.py +++ b/runtime/setup.py @@ -274,7 +274,8 @@ def build_configuration(cmake_build_dir, cmake_install_dir, extra_cmake_args=()) "IREE_HAL_DRIVER_VULKAN", "OFF" if platform.system() == "Darwin" else "ON", ), - get_env_cmake_list("IREE_EXTERNAL_HAL_DRIVERS", ""), + get_env_cmake_list("IREE_EXTERNAL_HAL_DRIVERS", + "" if platform.system() != "Linux" else "rocm"), get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), ] + list(extra_cmake_args) add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER") From f5f3dd859ed8acf115206e2c8d69191e2930abb4 Mon Sep 17 00:00:00 2001 From: powderluv Date: Sat, 8 Jul 2023 14:07:19 -0700 Subject: [PATCH 15/38] [BUILD] - Remove documentation build before publishing website --- .github/workflows/publish_website.yml | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/.github/workflows/publish_website.yml b/.github/workflows/publish_website.yml index 28803fb23fbf..e41b94627abe 100644 --- a/.github/workflows/publish_website.yml +++ b/.github/workflows/publish_website.yml @@ -44,13 +44,6 @@ jobs: with: python-version: 3.x cache: 'pip' - - id: "gcp-auth" - name: "Authenticating to Google Cloud" - uses: "google-github-actions/auth@v1" - with: - token_format: "access_token" - credentials_json: "${{ secrets.IREE_OSS_GITHUB_RUNNER_BASIC_TRUST_SERVICE_ACCOUNT_KEY }}" - create_credentials_file: false - name: Installing dependencies run: | pip install -r docs/website/requirements.txt @@ -60,14 +53,6 @@ jobs: ./build_tools/scripts/generate_release_index.py \ --repo="${GITHUB_REPOSITORY}" \ --output=docs/website/docs/pip-release-links.html - - name: Building documentation files - run: | - ./build_tools/github_actions/docker_run.sh \ - --env "IREE_CCACHE_GCP_TOKEN=${{ steps.gcp-auth.outputs.access_token }}" \ - --env "IREE_WRITE_REMOTE_CCACHE=1" \ - --env "CCACHE_NAMESPACE=gcr.io/iree-oss/base@sha256:796fb81a11ff7e7d057c93de468b74e48b6a9641aa19b7f7673c2772e8ea3b33" \ - gcr.io/iree-oss/base@sha256:796fb81a11ff7e7d057c93de468b74e48b6a9641aa19b7f7673c2772e8ea3b33 \ - ./docs/website/generate_extra_files.sh - name: Setting git config run: | git config --local user.email "iree-github-actions-bot@google.com" From fc5a2d5d0b6a8caa081499f6b51bc7c2512f3abf Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 20 Jul 2023 12:43:05 -0400 Subject: [PATCH 16/38] Drop CODEOWNERS to prevent sending review requests for SHARK-Runtime --- .github/CODEOWNERS | 83 ---------------------------------------------- 1 file changed, 83 deletions(-) delete mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 4b13c668f618..000000000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,83 +0,0 @@ -# Codeowners for IREE Github Repository. -# The listed owners will automatically be added as reviewers to PRs that modify -# paths matching the specified patterns. -# Refer to https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners -# for syntax of this file (tl;dr: syntax is like .gitignore. Last matching rule -# takes precedence). -# Because of the precedence, rules for directories are listed topologically. -# @ghost is used to make a pattern have no owners. It is a sentinel GitHub user -# that takes the place of deleted users. - -# No global owners because we don't really want e.g. changing the root -# CMakeLists.txt file to always ping a bunch of people. - -# Third-Party Code -/.gitmodules @ScottTodd @stellaraccident -/third_party/ @ScottTodd @stellaraccident -# Except for routinely-updated submodules -/third_party/llvm-project @ghost -/third_party/llvm-project.branch-pin @ghost -/third_party/stablehlo @ghost -/third_party/torch-mlir @ghost - -# Bindings -/runtime/bindings/python/ @stellaraccident -/runtime/bindings/tflite/ @benvanik - -# Integrations -/integrations/ @benvanik @stellaraccident -/integrations/tensorflow/ @stellaraccident -/integrations/tensorflow/test/**/iree_tfl_tests/ @rsuderman - -# Experimental -# It's experimental, but we still don't want any old directory added here. -/experimental/ @benvanik @stellaraccident -/experimental/cpu_ukernel/ @bjacob -/experimental/cuda2/ @antiagainst -/experimental/dispatch_profiler/ @manishucsd -/experimental/rocm/ @benvanik -/experimental/web/ @ScottTodd -/experimental/webgpu/ @benvanik @ScottTodd - -# Infra Top-Level Directories -/build_tools/ @ScottTodd @pzread -/build_tools/benchmarks/ @antiagainst @pzread -/build_tools/python/ @pzread -/build_tools/python_deploy/ @stellaraccident -/build_tools/scripts/ @ScottTodd -/build_tools/third_party/ @ScottTodd @stellaraccident -/.github/ @ScottTodd - -# llvm-external-projects -/llvm-external-projects/ @stellaraccident -/llvm-external-projects/iree-dialects/ @MaheshRavishankar -/llvm-external-projects/iree-dialects/**/Dialect/LinalgExt/ @hanhanW @MaheshRavishankar -/llvm-external-projects/iree-dialects/test/iree_linalgext @hanhanW @MaheshRavishankar - -# Other Top-Level Directories -/docs/ @ScottTodd -/samples/ @ScottTodd -/tools/ @benvanik - -# Compiler -/compiler/src/iree/compiler/ @benvanik -/compiler/src/iree/compiler/Codegen/ @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/Common @hanhanW @dcaballe -/compiler/src/iree/compiler/Codegen/Common/GPU @antiagainst @qedawkins -/compiler/src/iree/compiler/Codegen/LLVMCPU/ @dcaballe @hanhanW @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/LLVMGPU/ @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/SPIRV/ @antiagainst @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/TransformStrategies/ @qedawkins @MaheshRavishankar -/compiler/src/iree/compiler/ConstEval/ @hanhanW @stellaraccident -/compiler/src/iree/compiler/Dialect/Flow/ @hanhanW @MaheshRavishankar -/compiler/src/iree/compiler/Dialect/Vulkan/ @antiagainst -/compiler/src/iree/compiler/GlobalOptimization/ @hanhanW -/compiler/src/iree/compiler/InputConversion/ @MaheshRavishankar @stellaraccident -/compiler/plugins/input/StableHLO/ @hanhanW @MaheshRavishankar @rsuderman -/compiler/plugins/input/TOSA/ @MaheshRavishankar @rsuderman - -# Runtime -/runtime/src/iree/ @benvanik -/runtime/src/iree/hal/cts/ @ScottTodd -/runtime/src/iree/hal/drivers/metal/ @antiagainst -/runtime/src/iree/hal/drivers/vulkan/ @antiagainst @ScottTodd From fcae85717cb80330e14cffb8998b23debc6431ab Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 10 Aug 2023 11:44:55 -0700 Subject: [PATCH 17/38] [Distributed] Rudimentary distributed Python API (#64) * Add rudimentary non-production distributed Python API * Distributed execution validation Add functionality that validates distributed StableHLO is producing the same results as non-distributed. * Add execution time measurement * Distributed Python API: add call_count to run_ranks * Add setup script for distributed Python API * Add JAX to install setup --------- Co-authored-by: Boian Petkantchin --- runtime/bindings/python/CMakeLists.txt | 5 + .../bindings/python/iree/runtime/__init__.py | 1 + .../iree/runtime/distributed/__init__.py | 9 + .../iree/runtime/distributed/distributed.py | 86 +++++++ .../iree/runtime/distributed/run_rank.py | 132 +++++++++++ .../python/iree/runtime/distributed/setup.sh | 15 ++ .../distributed/sharding_pass_validation.py | 210 ++++++++++++++++++ .../python/iree/runtime/distributed/utils.py | 26 +++ 8 files changed, 484 insertions(+) create mode 100644 runtime/bindings/python/iree/runtime/distributed/__init__.py create mode 100644 runtime/bindings/python/iree/runtime/distributed/distributed.py create mode 100644 runtime/bindings/python/iree/runtime/distributed/run_rank.py create mode 100644 runtime/bindings/python/iree/runtime/distributed/setup.sh create mode 100644 runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py create mode 100644 runtime/bindings/python/iree/runtime/distributed/utils.py diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt index 44112ce51db9..a863fd2a5fe0 100644 --- a/runtime/bindings/python/CMakeLists.txt +++ b/runtime/bindings/python/CMakeLists.txt @@ -150,6 +150,11 @@ iree_py_library( "iree/_runtime/scripts/iree_run_trace/__main__.py" "iree/_runtime/scripts/iree_run_module/__main__.py" "iree/_runtime/scripts/iree_tracy_capture/__main__.py" + "iree/runtime/distributed/__init__.py" + "iree/runtime/distributed/distributed.py" + "iree/runtime/distributed/run_rank.py" + "iree/runtime/distributed/sharding_pass_validation.py" + "iree/runtime/distributed/utils.py" PYEXT_DEPS iree_runtime_bindings_python_PyExtRt ) diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py index d594b23c88d0..a9201b863d5f 100644 --- a/runtime/bindings/python/iree/runtime/__init__.py +++ b/runtime/bindings/python/iree/runtime/__init__.py @@ -66,4 +66,5 @@ from .io import * from .tracing import * +from . import distributed from . import flags diff --git a/runtime/bindings/python/iree/runtime/distributed/__init__.py b/runtime/bindings/python/iree/runtime/distributed/__init__.py new file mode 100644 index 000000000000..86ee5db110cc --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .distributed import prepare_shards_io_files, run_ranks + +__all__ = ["prepare_shards_io_files", "run_ranks"] diff --git a/runtime/bindings/python/iree/runtime/distributed/distributed.py b/runtime/bindings/python/iree/runtime/distributed/distributed.py new file mode 100644 index 000000000000..258e517b2cf2 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/distributed.py @@ -0,0 +1,86 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.compiler +import sys +import iree.runtime +from iree.runtime.array_interop import DeviceArray +import os +from numpy.typing import ArrayLike +from typing import List, Tuple +import tempfile +import subprocess +from . import utils + + +def prepare_shards_io_files( + inputs: List[List[ArrayLike]], out_dir: str +) -> Tuple[List[str], List[str]]: + input_filepaths = [] + output_filepaths = [] + for i in range(len(inputs)): + input_filepath = os.path.join(out_dir, f"shard_{i}", "input.npy") + input_filepaths.append(input_filepath) + os.makedirs(os.path.dirname(input_filepath)) + utils.write_numpy_arrays_to_file(filepath=input_filepath, arrays=inputs[i]) + output_filepath = os.path.join(out_dir, f"shard_{i}", "output.npy") + output_filepaths.append(output_filepath) + return input_filepaths, output_filepaths + + +def run_ranks( + num_ranks: int, + module_filepath: str, + function: str, + inputs: List[List[ArrayLike]], + driver: str, + call_count: int = 1, + measure_execution_time: bool = False, + warmup: int = 0, +) -> List[List[ArrayLike]]: + """ + Start all ranks with mpirun. + On all ranks run the function |function| from the given module. + Parameters + ---------- + inputs : Function inputs for all ranks. + Axis 0 is ranks. Axis 1 is arguments per rank. + Returns + ------- + The output of the function for all ranks. + Axis 0 is ranks. Axis 1 is arguments per rank. + """ + with tempfile.TemporaryDirectory() as out_dir: + input_filepaths, output_filepaths = prepare_shards_io_files( + inputs=inputs, out_dir=out_dir + ) + hal_driver = iree.runtime.get_driver(driver) + hal_driver.query_available_devices() + subprocess.check_call( + [ + "mpirun", + "--oversubscribe", + "-n", + str(num_ranks), + sys.executable, + os.path.join(os.path.dirname(__file__), "run_rank.py"), + f"--driver={driver}", + f"--module_filepath={module_filepath}", + f"--function={function}", + f"--call_count={call_count}", + ] + + (["--measure_execution_time"] if measure_execution_time else []) + + [ + f"--warmup={warmup}", + "--inputs", + ] + + input_filepaths + + ["--outputs"] + + output_filepaths + ) + return [ + utils.read_numpy_arrays_from_file(out_file) for out_file in output_filepaths + ] diff --git a/runtime/bindings/python/iree/runtime/distributed/run_rank.py b/runtime/bindings/python/iree/runtime/distributed/run_rank.py new file mode 100644 index 000000000000..7ad00f7256cc --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/run_rank.py @@ -0,0 +1,132 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.compiler +import argparse +import iree.runtime +from iree.runtime.array_interop import DeviceArray +from mpi4py import MPI +import utils +import datetime +import numpy as np + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run 1 shard.") + parser.add_argument("--driver", type=str, default="local-task", help="Device URI.") + parser.add_argument( + "--module_filepath", type=str, required=True, help="Path to IREE module." + ) + parser.add_argument( + "--function", type=str, required=True, help="Name of function to call." + ) + parser.add_argument( + "--call_count", + type=int, + default=1, + help="How many times to call the function during time measurement.", + ) + parser.add_argument( + "--measure_execution_time", + action="store_true", + default=False, + help="Measure execution time in seconds f64 and append to results.", + ) + parser.add_argument( + "--warmup", + type=int, + default=0, + help="How many warmup calls to do before the actual call that generates the result.", + ) + parser.add_argument( + "--inputs", + nargs="+", + type=str, + required=True, + help="Path to IREE module inputs for all ranks in npy format.", + ) + parser.add_argument( + "--outputs", + nargs="+", + type=str, + required=True, + help="Path to IREE module outputs form all ranks in npy format.", + ) + return parser.parse_args() + + +def run_module( + device: iree.runtime.HalDevice, + module_filepath: str, + function: str, + call_count: int, + input_filepath: str, + output_filepath: str, + measure_execution_time: bool, + warmup: int, +): + config = iree.runtime.Config(device=device) + with open(module_filepath, "rb") as f: + vm_flatbuffer = f.read() + vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer) + bound_module = iree.runtime.load_vm_module(vm_module, config) + input_args = utils.read_numpy_arrays_from_file(input_filepath) + input_args_on_device = [ + iree.runtime.asdevicearray(device, arr) for arr in input_args + ] + for _ in range(warmup): + getattr(bound_module, function)(*input_args_on_device) + if measure_execution_time: + # Sync all ranks + MPI.COMM_WORLD.barrier() + start_time = datetime.datetime.now() + assert call_count > 0 + for _ in range(call_count): + results = getattr(bound_module, function)(*input_args_on_device) + if measure_execution_time: + end_time = datetime.datetime.now() + if isinstance(results, DeviceArray): + results = [results] + if measure_execution_time: + if isinstance(results, tuple): + results = list(results) + results.append( + np.array((end_time - start_time).total_seconds() / call_count, dtype=float) + ) + utils.write_numpy_arrays_to_file(filepath=output_filepath, arrays=results) + + +def run_rank( + driver: str, + module_filepath: str, + function: str, + inputs: str, + outputs: str, + call_count: int, + measure_execution_time: bool, + warmup: int, +): + rank = MPI.COMM_WORLD.Get_rank() + hal_driver = iree.runtime.get_driver(driver) + device_infos = hal_driver.query_available_devices() + device = hal_driver.create_device( + device_infos[rank % len(device_infos)]["device_id"] + ) + run_module( + device=device, + module_filepath=module_filepath, + function=function, + call_count=call_count, + input_filepath=inputs[rank], + output_filepath=outputs[rank], + measure_execution_time=measure_execution_time, + warmup=warmup, + ) + + +if __name__ == "__main__": + args = parse_args() + run_rank(**vars(args)) diff --git a/runtime/bindings/python/iree/runtime/distributed/setup.sh b/runtime/bindings/python/iree/runtime/distributed/setup.sh new file mode 100644 index 000000000000..83dca488caa4 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/setup.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -e + +distribution=$(. /etc/os-release;echo $ID$VERSION_ID | sed -e 's/\.//g') +wget -O /tmp/cuda-keyring_1.0-1_all.deb \ + https://developer.download.nvidia.com/compute/cuda/repos/$distribution/x86_64/cuda-keyring_1.0-1_all.deb +sudo dpkg -i /tmp/cuda-keyring_1.0-1_all.deb +sudo apt update +# For CMake to find CUDA when using LLD. +sudo apt -y install lld + +sudo apt -y install libopenmpi-dev +sudo apt -y install libnccl-dev=2.18.1-1+cuda12.1 +pip install mpi4py jax[cpu] diff --git a/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py new file mode 100644 index 000000000000..446520451623 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py @@ -0,0 +1,210 @@ +import iree.compiler +import iree.runtime +import os +from iree.runtime.distributed import run_ranks +import subprocess +from pathlib import Path +from jax._src.lib import xla_client +from jaxlib.xla_client import HloSharding +from typing import List, Tuple, Union +from numpy.typing import ArrayLike +import jax +from jax._src.sharding_impls import GSPMDSharding +import jax._src.interpreters.pxla as pxla +import numpy as np +from datetime import timedelta + +xla_extension = xla_client._xla + + +def compile_mlir(mlir_filepath: str, output_filepath: str, use_cache: bool, **kwargs): + if use_cache and os.path.exists(output_filepath): + return + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + iree.compiler.compile_file( + input_file=mlir_filepath, output_file=output_filepath, **kwargs + ) + + +def extract_args_sharding( + xla_computation: xla_extension.XlaComputation, +) -> List[HloSharding]: + return [ + HloSharding.from_proto(sharding) + for sharding in xla_computation.get_hlo_module().spmd_parameters_shardings + ] + + +def extract_results_sharding( + xla_computation: xla_extension.XlaComputation, +) -> List[HloSharding]: + sharding = HloSharding.from_proto( + xla_computation.get_hlo_module().spmd_output_sharding + ) + if len(sharding.tuple_elements()): + return sharding.tuple_elements() + else: + return [sharding] + + +def shard_arg(arg: ArrayLike, sharding: HloSharding) -> List[ArrayLike]: + gspmd_sharding = GSPMDSharding(devices=jax.local_devices(), op_sharding=sharding) + indices = gspmd_sharding.devices_indices_map(arg.shape).values() + sharded_array = pxla.shard_arg( + arg, devices=jax.local_devices(), arg_indices=indices, sharding=gspmd_sharding + ) + return [shard.data for shard in sharded_array.global_shards] + + +def shard_args( + args: List[ArrayLike], shardings: List[HloSharding] +) -> List[List[ArrayLike]]: + assert len(args) == len(shardings) + return [shard_arg(arg, sharding) for arg, sharding in zip(args, shardings)] + + +def assemble_shards(shards: List[ArrayLike], sharding: HloSharding) -> ArrayLike: + if sharding.is_replicated(): + return shards[0] + else: + raise NotImplementedError() + + +def propagate_shardings_and_spmd_partition( + mlir_filepath: str, + output_filepath: str, + num_devices: int, + use_cache: bool, + allow_spmd_sharding_propagation_to_output: int = 1, +): + res = subprocess.run( + [ + "stablehlo-opt", + ( + "--pass-pipeline=builtin.module(stablehlo-xla-sharding-propagation-and-spmd-partitioner{" + "is_spmd=1 " + f"allow_spmd_sharding_propagation_to_output={allow_spmd_sharding_propagation_to_output} " + "allow_spmd_sharding_propagation_to_parameters=1 " + f"num_partitions={num_devices} " + "num_replicas=1})" + ), + mlir_filepath, + ], + check=True, + stdout=subprocess.PIPE, + ) + Path(output_filepath).parent.mkdir(parents=True, exist_ok=True) + if use_cache and os.path.exists(output_filepath): + return + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + with open(output_filepath, "wb") as f: + f.write(res.stdout) + + +def swap_shard_axis(arrays: List[ArrayLike]) -> List[List[ArrayLike]]: + """Swap axis 0 with 1.""" + if len(arrays) == 0: + return [] + expected_shards = len(arrays[0]) + res = [[] for _ in range(expected_shards)] + for arr in arrays: + assert len(arr) == expected_shards + for shard in range(expected_shards): + res[shard].append(arr[shard]) + return res + + +def execute_distributed( + num_ranks: int, + mlir_filepath: str, + iree_module_filepath: str, + function: str, + inputs: List[ArrayLike], + driver: str, + measure_execution_time: bool = False, +) -> Union[List[ArrayLike], Tuple[List[ArrayLike], timedelta]]: + with open(mlir_filepath, "r") as f: + mlir_str = f.read() + xla_computation = xla_extension.mlir.mlir_module_to_xla_computation( + mlir_module=mlir_str, use_tuple_args=False, return_tuple=False + ) + args_sharding = extract_args_sharding(xla_computation) + results_sharding = extract_results_sharding(xla_computation) + sharded_args = shard_args(args=inputs, shardings=args_sharding) + sharded_args = swap_shard_axis(sharded_args) + sharded_results = run_ranks( + num_ranks=num_ranks, + module_filepath=iree_module_filepath, + function=function, + inputs=sharded_args, + driver=driver, + ) + sharded_results = swap_shard_axis(sharded_results) + if measure_execution_time: + sharded_results, execution_times = sharded_results + res = [ + assemble_shards(shards=result_shards, sharding=sharding) + for result_shards, sharding in zip(sharded_results, results_sharding) + ] + if measure_execution_time: + res = res, timedelta(seconds=np.max(execution_times)) + return res + + +def validate_sharding_passes( + mlir_filepath: str, + mlir_with_sharding_annotations_filepath: str, + inputs: List[ArrayLike], + function: str, + num_devices: int, + use_cache: bool, + driver: str, + target_backend: str, + output_prefix_path: str, + allow_spmd_sharding_propagation_to_output: int = 1, +): + # Single instance. + iree_module_filepath = ( + f"{output_prefix_path}{os.path.basename(mlir_filepath)}.{driver}.vmfb" + ) + compile_mlir( + mlir_filepath=mlir_filepath, + output_filepath=iree_module_filepath, + use_cache=use_cache, + target_backends=[target_backend], + ) + iree_module = iree.runtime.load_vm_flatbuffer_file( + path=iree_module_filepath, driver=driver + ) + results = iree_module[function](*inputs) + if isinstance(results, iree.runtime.DeviceArray): + results = [results] + + # Distributed. + spmd_mlir_filepath = f"{output_prefix_path}{os.path.basename(mlir_with_sharding_annotations_filepath)}.spmd.mlir" + propagate_shardings_and_spmd_partition( + mlir_filepath=mlir_with_sharding_annotations_filepath, + output_filepath=spmd_mlir_filepath, + num_devices=num_devices, + use_cache=use_cache, + allow_spmd_sharding_propagation_to_output=allow_spmd_sharding_propagation_to_output, + ) + spmd_iree_module_filepath = f"{output_prefix_path}{os.path.basename(spmd_mlir_filepath)}.{target_backend}.vmfb" + compile_mlir( + mlir_filepath=spmd_mlir_filepath, + output_filepath=spmd_iree_module_filepath, + use_cache=use_cache, + target_backends=[target_backend], + ) + spmd_results = execute_distributed( + num_ranks=num_devices, + mlir_filepath=spmd_mlir_filepath, + iree_module_filepath=spmd_iree_module_filepath, + function=function, + inputs=inputs, + driver=driver, + ) + + assert len(results) == len(spmd_results) + for result, spmd_result in zip(results, spmd_results): + np.testing.assert_allclose(result, spmd_result, atol=1e-7) diff --git a/runtime/bindings/python/iree/runtime/distributed/utils.py b/runtime/bindings/python/iree/runtime/distributed/utils.py new file mode 100644 index 000000000000..3581baf354f8 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/utils.py @@ -0,0 +1,26 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from numpy.typing import ArrayLike +from typing import List +import numpy as np + + +def read_numpy_arrays_from_file(filepath: str) -> List[ArrayLike]: + res = [] + with open(filepath, "rb") as f: + while True: + try: + res.append(np.load(f)) + except EOFError: + break + return res + + +def write_numpy_arrays_to_file(filepath: str, arrays: List[ArrayLike]): + with open(filepath, "wb") as f: + for arr in arrays: + np.save(f, np.asarray(arr), allow_pickle=False) From 4e979d4720db2e18b3b9d375bf3520db6c89873f Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 10 Aug 2023 13:13:19 -0700 Subject: [PATCH 18/38] [Distributed] Add example to run a simple model across 2 GPUs --- samples/distributed/example.py | 55 ++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 samples/distributed/example.py diff --git a/samples/distributed/example.py b/samples/distributed/example.py new file mode 100644 index 000000000000..ff0989403df9 --- /dev/null +++ b/samples/distributed/example.py @@ -0,0 +1,55 @@ +from iree.runtime.distributed import run_ranks +import iree.compiler +import tempfile +import numpy as np +import os + +""" +Example of distributed execution across 2 devices of a small model +with just an all-reduce operation. +all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) -> [6, 8, 10, 12]. + +Dependecies at: +runtime/bindings/python/iree/runtime/distributed/setup.sh +""" +mlir = """ + func.func @all_reduce_sum(%input : tensor<4xf32>) -> tensor<4xf32> { + %out = "stablehlo.all_reduce"(%input) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %sum = stablehlo.add %arg0, %arg1 : tensor + stablehlo.return %sum : tensor + }) {channel_handle = #stablehlo.channel_handle, + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + use_global_device_ids} : (tensor<4xf32>) -> tensor<4xf32> + return %out : tensor<4xf32> + } +""" + +inputs = [ + [np.array([1, 2, 3, 4], dtype=np.float32)], + [np.array([5, 6, 7, 8], dtype=np.float32)], +] + +for rank in range(len(inputs)): + print(f"Rank {rank} argument = {inputs[rank]}") + +with tempfile.TemporaryDirectory() as tmp_dir: + module_filepath = os.path.join(tmp_dir, "module.vmfb") + iree.compiler.tools.compile_str( + input_str=mlir, + output_file=module_filepath, + target_backends=["cuda"], + input_type="stablehlo", + ) + + num_ranks = len(inputs) + # Ranks on the 0th axis. + outputs = run_ranks( + num_ranks=num_ranks, + function="all_reduce_sum", + driver="cuda", + module_filepath=module_filepath, + inputs=inputs, + ) + for rank in range(num_ranks): + print(f"Rank {rank} result = {outputs[rank]}") From bfa3aca774e9a0e0f04205a42286a14179bcd87f Mon Sep 17 00:00:00 2001 From: Anush Elangovan Date: Wed, 14 Sep 2022 05:31:02 -0700 Subject: [PATCH 19/38] [CI] Switch to GHA linux runners, remove TF builds, move macOS to self-hosted, clean macos bindist Drop instrumented builds and Python < 3.11 Add Upstream sync CI This fixes the problem of potentially dropping commits that have been submitted while an automatic rebase with upstream IREE is goining on. [CI] Fix macos clean up logic Fixes the macos builder. --- .github/workflows/build_package.yml | 44 ++++-------- .github/workflows/sync.yml | 69 +++++++++++++++++++ .../python_deploy/build_linux_packages.sh | 2 +- build_tools/scripts/get_latest_green.sh | 11 --- 4 files changed, 83 insertions(+), 43 deletions(-) create mode 100644 .github/workflows/sync.yml diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 7e044ea4a1d6..aa5f68d46efc 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -39,47 +39,25 @@ jobs: matrix: include: # Ubuntu packages. - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux-x86_64 + - runs-on: ubuntu-latest + build-family: linux build-package: main-dist-linux experimental: false - - runs-on: [self-hosted, arm64, os-family=Linux, runner-group=postsubmit] - build-family: linux-aarch64 - build-package: main-dist-linux - experimental: true - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux-x86_64 + - runs-on: ubuntu-latest + build-family: linux build-package: py-compiler-pkg experimental: false - - runs-on: [self-hosted, arm64, os-family=Linux, runner-group=postsubmit] - build-family: linux-aarch64 - build-package: py-compiler-pkg - experimental: true - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux-x86_64 - build-package: py-runtime-pkg - experimental: false - - runs-on: [self-hosted, arm64, os-family=Linux, runner-group=postsubmit] - build-family: linux-aarch64 + - runs-on: ubuntu-latest + build-family: linux build-package: py-runtime-pkg - experimental: true - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux-x86_64 - build-package: py-tf-compiler-tools-pkg experimental: false - # MacOS packages. - - runs-on: - - ${{ github.repository == 'openxla/iree' && 'self-hosted' || 'macos-11' }} - - os-family=macOS - - runner-group=postsubmit + # Macos packages. + - runs-on: MacStudio build-family: macos build-package: py-compiler-pkg experimental: true - - runs-on: - - ${{ github.repository == 'openxla/iree' && 'self-hosted' || 'macos-11' }} - - os-family=macOS - - runner-group=postsubmit + - runs-on: MacStudio build-family: macos build-package: py-runtime-pkg experimental: true @@ -106,6 +84,10 @@ jobs: path: "c" # Windows can hit path length limits, so use a short path. submodules: true ref: ${{ github.event.inputs.commit }} + - uses: actions/setup-python@v4 + if: "matrix.build-family != 'macos'" + with: + python-version: '3.11' ########################################################################## # OS specific setup diff --git a/.github/workflows/sync.yml b/.github/workflows/sync.yml new file mode 100644 index 000000000000..2fb41e7b7c68 --- /dev/null +++ b/.github/workflows/sync.yml @@ -0,0 +1,69 @@ +name: 'Sync Upstream' + +on: + workflow_dispatch: + schedule: + - cron: '0 * * * *' + +jobs: + sync_upstream: + name: 'Sync Upstream' + runs-on: ubuntu-latest + steps: + - name: Checking out repository + uses: actions/checkout@v3 + with: + token: ${{ secrets.CI_WRITE_TOKEN }} + repository: nod-ai/shark-runtime + ref: main + fetch-depth: 0 + + - name: Setup git + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "SHARK bot" + + - name: Update main upstream + run: | + set -ex + git remote add upstream https://github.com/iree-org/iree + git pull --ff-only upstream main + + - name: Pushing changes + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.CI_WRITE_TOKEN }} + branch: main + repository: nod-ai/shark-runtime + + rebase_shark: + name: 'Rebase SHARK' + runs-on: ubuntu-latest + steps: + - name: Checking out repository + uses: actions/checkout@v3 + with: + token: ${{ secrets.CI_WRITE_TOKEN }} + repository: nod-ai/shark-runtime + ref: shark + fetch-depth: 0 + + - name: Setup git + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "SHARK bot" + + - name: Update shark upstream + run: | + set -ex + git remote add upstream https://github.com/iree-org/iree + git fetch upstream + git rebase upstream/main + + - name: Pushing changes + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.CI_WRITE_TOKEN }} + branch: shark + repository: nod-ai/shark-runtime + force_with_lease: true diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 9116db182c2d..3363d35b421c 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -65,7 +65,7 @@ this_dir="$(cd $(dirname $0) && pwd)" script_name="$(basename $0)" repo_root=$(cd "${this_dir}" && find_git_dir_parent) manylinux_docker_image="${manylinux_docker_image:-$(uname -m | awk '{print ($1 == "aarch64") ? "quay.io/pypa/manylinux_2_28_aarch64" : "ghcr.io/nod-ai/manylinux_x86_64:main" }')}" -python_versions="${override_python_versions:-cp39-cp39 cp310-cp310 cp311-cp311}" +python_versions="${override_python_versions:-cp311-cp311}" output_dir="${output_dir:-${this_dir}/wheelhouse}" packages="${packages:-iree-runtime iree-compiler}" package_suffix="${package_suffix:-}" diff --git a/build_tools/scripts/get_latest_green.sh b/build_tools/scripts/get_latest_green.sh index 979acb2b6ec0..ea08d08125ce 100755 --- a/build_tools/scripts/get_latest_green.sh +++ b/build_tools/scripts/get_latest_green.sh @@ -36,17 +36,6 @@ function get_latest_green() { local query_string="$(IFS="&" ; echo "${query_params[*]}")" local all_passing="true" - for workflow in "${REQUIRED_WORKFLOWS[@]}"; do - local successful_run_count="$(\ - gh api --jq '.total_count' \ - "/repos/openxla/iree/actions/workflows/${workflow}/runs?${query_string}" \ - )" - # Any successful run of the workflow (including reruns) is OK. - if (( successful_run_count==0 )); then - all_passing="false" - break - fi - done if [[ "${all_passing}" == true ]]; then echo "${commit}" return 0 From c14ae5e24a4a96f90100aafad043f2475f888236 Mon Sep 17 00:00:00 2001 From: powderluv Date: Sat, 3 Jun 2023 06:45:38 -0700 Subject: [PATCH 20/38] [CI] Add AArch64 builder, disable tests --- .github/workflows/build_package.yml | 67 +++++++++-- .../validate_and_publish_release.yml | 106 ------------------ compiler/setup.py | 4 +- runtime/setup.py | 2 +- 4 files changed, 59 insertions(+), 120 deletions(-) diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index aa5f68d46efc..0d8b588c4043 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -39,15 +39,11 @@ jobs: matrix: include: # Ubuntu packages. - - runs-on: ubuntu-latest - build-family: linux - build-package: main-dist-linux - experimental: false - - runs-on: ubuntu-latest + - runs-on: icelake build-family: linux build-package: py-compiler-pkg experimental: false - - runs-on: ubuntu-latest + - runs-on: icelake build-family: linux build-package: py-runtime-pkg experimental: false @@ -73,19 +69,35 @@ jobs: build-package: py-runtime-pkg experimental: true + # Linux AArch64 packages. + - runs-on: linux-aarch64 + build-family: linux-aarch64 + build-package: py-compiler-pkg + experimental: false + - runs-on: linux-aarch64 + build-family: linux-aarch64 + build-package: py-runtime-pkg + experimental: false + + env: # These are also set in: build_tools/python_deploy/build_linux_packages.sh MANYLINUX_X86_64_IMAGE: ghcr.io/nod-ai/manylinux_x86_64:main MANYLINUX_AARCH64_IMAGE: quay.io/pypa/manylinux_2_28_aarch64 steps: + # Docker may leave root owned files + - name: Chown user + if: "matrix.build-family == 'linux-aarch64' || matrix.build-family == 'linux'" + run: | + sudo chown -R $USER:$USER $GITHUB_WORKSPACE - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 with: path: "c" # Windows can hit path length limits, so use a short path. submodules: true ref: ${{ github.event.inputs.commit }} - uses: actions/setup-python@v4 - if: "matrix.build-family != 'macos'" + if: "matrix.build-family == 'windows'" with: python-version: '3.11' @@ -222,6 +234,16 @@ jobs: [ -e ./bindist/* ] && rm ./bindist/* ./c/build_tools/python_deploy/build_linux_packages.sh + - name: Build compiler wheels (Linux-AArch64) + if: "matrix.build-package == 'py-compiler-pkg' && matrix.build-family == 'linux-aarch64'" + shell: bash + env: + package_suffix: ${{ github.event.inputs.package_suffix }} + packages: "iree-compiler" + output_dir: "${{ github.workspace }}/bindist" + run: | + ./c/build_tools/python_deploy/build_linux_packages.sh + - name: Build compiler wheels (MacOS) if: "matrix.build-package == 'py-compiler-pkg' && matrix.build-family == 'macos'" shell: bash @@ -270,10 +292,10 @@ jobs: path: ./bindist/* retention-days: 5 - # TODO: Upload the tar.bz2 files too when ready - - name: Upload Release Assets - if: github.event.inputs.release_id != '' - id: upload-release-assets + # TODO: One Window Release builds we build both compiler+runtime + - name: Upload Release Assets (Windows) + if: "github.event.inputs.release_id != '' && matrix.build-family == 'windows'" + id: upload-release-assets-windows uses: dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 env: GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} @@ -282,6 +304,29 @@ jobs: # Only upload iree artifacts. assets_path: ./bindist/iree*.* + # TODO: Upload the tar.bz2 files too when ready + - name: Upload Release Assets (Compiler) + if: "github.event.inputs.release_id != '' && matrix.build-package == 'py-compiler-pkg' && matrix.build-family != 'windows'" + id: upload-release-assets-compiler + uses: dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 + env: + GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + # Only upload iree artifacts. + assets_path: ./bindist/iree_compiler*.* + + - name: Upload Release Assets (Runtime) + if: "github.event.inputs.release_id != '' && matrix.build-package == 'py-runtime-pkg' && matrix.build-family != 'windows'" + id: upload-release-assets-runtime + uses: dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 + env: + GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + # Only upload iree artifacts. + assets_path: ./bindist/iree_runtime*.* + validate_and_publish: name: "Trigger validate and publish release" needs: build_packages diff --git a/.github/workflows/validate_and_publish_release.yml b/.github/workflows/validate_and_publish_release.yml index 41d19b40c9f9..e86c2665ea6e 100644 --- a/.github/workflows/validate_and_publish_release.yml +++ b/.github/workflows/validate_and_publish_release.yml @@ -16,100 +16,8 @@ on: required: true jobs: - validate_packages: - name: "Validate packages" - # TODO(jennik): Look into testing windows and macos builds. - runs-on: ubuntu-20.04 - steps: - - name: Download packages - id: download_packages - uses: dawidd6/action-download-artifact@5e780fc7bbd0cac69fc73271ed86edf5dcb72d67 # v2.26.0 - with: - github_token: ${{secrets.WRITE_ACCESS_TOKEN}} - workflow: build_package.yml - run_id: ${{ github.event.inputs.build_run_id }} - - name: Extract and display downloaded files - run: | - tar -xf artifact/iree-dist-${{ github.event.inputs.package_version }}-linux-x86_64.tar.xz - pwd - ls -R - - name: Set up python - id: set_up_python - uses: actions/setup-python@d27e3f3d7c64b4bbf8e4abfb9b63b83e846e0435 # v4.5.0 - with: - python-version: "3.9" - - name: Install python packages - id: install_python_packages - run: | - python -m pip install -f file://$PWD/artifact/ iree-compiler iree-runtime iree-tools-tflite iree-tools-tf - - name: Validate IREE Runtime Package - id: validate_runtime_package - run: | - echo "Testing default runtime:" - python -m iree.runtime._package_test - echo "Testing tracy runtime:" - # GH runners don't expose the TSC but we want to make sure the basic packaging - # works, so override the check with TRACY_NO_INVARIANT_CHECK=1 (per instructions - # if this is left off). - TRACY_NO_INVARIANT_CHECK=1 IREE_PY_RUNTIME=tracy \ - python -m iree.runtime._package_test - # Binaries from the tarball - - name: Run iree-benchmark-module - id: run_iree_benchmark_module - run: ./bin/iree-benchmark-module --help - - name: Run iree-benchmark-trace - id: run_iree_benchmark_trace - run: ./bin/iree-benchmark-trace --help - - name: Run iree-dump-module - id: run_iree_dump_module - run: ./bin/iree-dump-module --help - - name: Run iree-cpuinfo - id: run_iree_cpuinfo - run: ./bin/iree-cpuinfo - - name: Run iree-flatcc-cli - id: run_iree_flatcc_cli - run: ./bin/iree-flatcc-cli --help - - name: Run iree-opt - id: run_iree_opt - run: ./bin/iree-opt --help - - name: Run iree-run-mlir - id: run_iree_run_mlir - run: ./bin/iree-run-mlir --help - - name: Run iree-run-module - id: run_iree_run_module - run: ./bin/iree-run-module --help - - name: Run iree-run-trace - id: run_iree_run_trace - run: ./bin/iree-run-trace --help - - name: Run iree-tblgen - id: run_iree_tblgen - run: ./bin/iree-tblgen --help - - name: Run iree-compile - id: run_iree-compile - run: ./bin/iree-compile --help - # Console scripts from the wheels. - - name: Py iree-run-module - id: py_iree-run-module - run: iree-run-module --help - - name: Py iree-run-trace - id: py_iree-run-trace - run: iree-run-trace --help - - name: Py iree-benchmark-module - id: py_iree_benchmark_module - run: iree-benchmark-module --help - - name: Py iree-benchmark-trace - id: py_iree_benchmark_trace - run: iree-benchmark-trace --help - - name: Py iree-dump-module - id: py_iree_dump_module - run: iree-dump-module --help - - name: Py iree-cpuinfo - id: py_iree_cpuinfo - run: iree-cpuinfo - publish_release: name: "Publish release" - needs: validate_packages runs-on: ubuntu-20.04 steps: - name: Publish Release @@ -120,17 +28,3 @@ jobs: with: release_id: ${{ github.event.inputs.release_id }} - - name: Checking out repository - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 - with: - token: ${{ secrets.WRITE_ACCESS_TOKEN }} - # Get all history. Otherwise the latest-snapshot branch can't be - # fast-forwarded. - fetch-depth: 0 - - - name: Updating latest-snapshot branch - uses: ad-m/github-push-action@40bf560936a8022e68a3c00e7d2abefaf01305a6 # v0.6.0 - with: - github_token: ${{ secrets.WRITE_ACCESS_TOKEN }} - branch: latest-snapshot - force: true diff --git a/compiler/setup.py b/compiler/setup.py index 4ff0ddf68f04..df229d993143 100644 --- a/compiler/setup.py +++ b/compiler/setup.py @@ -251,11 +251,11 @@ def prepare_installation(): "-GNinja", "--log-level=VERBOSE", "-DIREE_BUILD_PYTHON_BINDINGS=ON", - "-DIREE_BUILD_SAMPLES=OFF", - "-DIREE_BUILD_TESTS=OFF", # Disable .so.0 style symlinking. Python wheels don't preserve links, # so this ~doubles the binary size if not disabled (yikes!). "-DCMAKE_PLATFORM_NO_VERSIONED_SONAME=ON", + "-DIREE_BUILD_TESTS=OFF", + "-DIREE_BUILD_SAMPLES=OFF", "-DPython3_EXECUTABLE={}".format(sys.executable), "-DCMAKE_BUILD_TYPE={}".format(cfg), # TODO(scotttodd): include IREE_TARGET_BACKEND_WEBGPU here (and in env) diff --git a/runtime/setup.py b/runtime/setup.py index 345f31fc9291..b7854c02643f 100644 --- a/runtime/setup.py +++ b/runtime/setup.py @@ -275,7 +275,7 @@ def build_configuration(cmake_build_dir, cmake_install_dir, extra_cmake_args=()) "OFF" if platform.system() == "Darwin" else "ON", ), get_env_cmake_list("IREE_EXTERNAL_HAL_DRIVERS", - "" if platform.system() != "Linux" else "rocm"), + "" if sysconfig.get_platform() != "linux-x86_64" else "rocm;level_zero"), get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), ] + list(extra_cmake_args) add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER") From 12900681f09567a7523f1dee1da2709479fa1a80 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Sat, 12 Aug 2023 12:58:04 -0700 Subject: [PATCH 21/38] Remove dependency of iree.runtime to iree.runtime.distributed --- runtime/bindings/python/iree/runtime/__init__.py | 1 - runtime/bindings/python/iree/runtime/distributed/distributed.py | 1 - runtime/bindings/python/iree/runtime/distributed/run_rank.py | 1 - .../python/iree/runtime/distributed/sharding_pass_validation.py | 2 +- 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py index a9201b863d5f..d594b23c88d0 100644 --- a/runtime/bindings/python/iree/runtime/__init__.py +++ b/runtime/bindings/python/iree/runtime/__init__.py @@ -66,5 +66,4 @@ from .io import * from .tracing import * -from . import distributed from . import flags diff --git a/runtime/bindings/python/iree/runtime/distributed/distributed.py b/runtime/bindings/python/iree/runtime/distributed/distributed.py index 258e517b2cf2..31e0a5e13a42 100644 --- a/runtime/bindings/python/iree/runtime/distributed/distributed.py +++ b/runtime/bindings/python/iree/runtime/distributed/distributed.py @@ -4,7 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import iree.compiler import sys import iree.runtime from iree.runtime.array_interop import DeviceArray diff --git a/runtime/bindings/python/iree/runtime/distributed/run_rank.py b/runtime/bindings/python/iree/runtime/distributed/run_rank.py index 7ad00f7256cc..86761d3172b2 100644 --- a/runtime/bindings/python/iree/runtime/distributed/run_rank.py +++ b/runtime/bindings/python/iree/runtime/distributed/run_rank.py @@ -4,7 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import iree.compiler import argparse import iree.runtime from iree.runtime.array_interop import DeviceArray diff --git a/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py index 446520451623..599d6604b8a8 100644 --- a/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py +++ b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py @@ -1,7 +1,7 @@ import iree.compiler import iree.runtime import os -from iree.runtime.distributed import run_ranks +from .distributed import run_ranks import subprocess from pathlib import Path from jax._src.lib import xla_client From 0ee80fc23e7e69ebf1f99a52a9c7cd785371876e Mon Sep 17 00:00:00 2001 From: powderluv Date: Mon, 14 Aug 2023 22:31:02 -0700 Subject: [PATCH 22/38] Switch windows to self-hosted --- .github/workflows/build_package.yml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 0d8b588c4043..172f09f9c7ea 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -60,11 +60,11 @@ jobs: # Windows packages. - runs-on: - - ${{ github.repository == 'openxla/iree' && 'windows-2022-64core' || 'windows-2022'}} + - ${{ github.repository == 'openxla/iree' && 'windows-2022-64core' || '7950X'}} build-family: windows build-package: py-compiler-pkg experimental: true - - runs-on: windows-2022 + - runs-on: 7950X build-family: windows build-package: py-runtime-pkg experimental: true @@ -105,13 +105,15 @@ jobs: # OS specific setup ########################################################################## - - name: Install dependencies (Windows) - if: "matrix.build-family == 'windows'" - shell: powershell - run: ./c/build_tools/python_deploy/install_windows_deps.ps1 + #- name: Install dependencies (Windows) + # if: "matrix.build-family == 'windows'" + # shell: powershell + # run: ./c/build_tools/python_deploy/install_windows_deps.ps1 - name: "Configure MSVC (Windows)" if: "matrix.build-family == 'windows'" uses: ilammy/msvc-dev-cmd@7315a94840631165970262a99c72cfb48a65d25d # v1.12.0 + with: + arch: x64 ########################################################################## # Write version_info.json From fd437ada654df1d2a821f5309f2197da1bf0495d Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 25 Aug 2023 03:33:22 -0400 Subject: [PATCH 23/38] Revert "[codegen][spirv] Pack/transpose matrix B for better coop mmma" This reverts commit a6512dc5ba99912eaa24cd0b8a5892db1b38e053. --- .../Flow/Transforms/FormDispatchRegions.cpp | 9 +- .../Dialect/Flow/Transforms/Passes.cpp | 5 - .../compiler/Preprocessing/Common/BUILD.bazel | 3 +- .../Preprocessing/Common/CMakeLists.txt | 1 - .../Common/ConvertLinalgMatmulToMmt.cpp | 119 -------------- .../Common/GeneralizeAndFuse.cpp | 148 ------------------ .../compiler/Preprocessing/Common/Passes.h | 6 - .../compiler/Preprocessing/Common/Passes.td | 12 -- 8 files changed, 2 insertions(+), 301 deletions(-) delete mode 100644 compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp delete mode 100644 compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp index cd4a8987d59f..11fba41d9a64 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp @@ -662,14 +662,7 @@ isFusableWithProducer(OpOperand &operand, } auto consumerLinalgOp = cast(consumer); - if (consumerLinalgOp.isDpsInput(&operand)) { - // TODO: Add some marker on transpose and MatmulOp to indicate mmt. - bool fuseTransposeAndMatmul = - isa(consumer) && isa(producer); - if (fuseTransposeAndMatmul) { - return true; - } - } else if (!consumerLinalgOp.isDpsInit(&operand)) { + if (!consumerLinalgOp.isDpsInit(&operand)) { return false; } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index ca61270289e9..d094e0864834 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -85,11 +85,6 @@ static llvm::cl::opt clDispatchGenerateWorkloadRegion( llvm::cl::desc("Generate the workload region."), llvm::cl::init(true)); -static llvm::cl::opt clEnableTransposeMatmulLayout( - "iree-flow-enable-transpose-matmul-layout", - llvm::cl::desc("Enable transposing the B matrix for matmuls."), - llvm::cl::init(false)); - static llvm::cl::opt clNormalizeInputIndexingMap( "iree-flow-normalize-input-indexing-map", llvm::cl::desc("Enable normalizing input indexing map to identity."), diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 08795e578b27..7c43380e1f66 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -34,10 +34,9 @@ iree_compiler_cc_library( "ConvertConvNchwToNhwc.cpp", "ConvertConvToChannelsLast.cpp", "ConvertLinalgMatmulToMmt.cpp", - "GeneralizeAndFuse.cpp", "GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", - "PadLinalgOps.cpp", + "PadLinalgOps.cpp", "PassDetail.h", "Passes.cpp", ], diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index 1b83bc1ea1d7..a147974d907d 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -30,7 +30,6 @@ iree_cc_library( "ConvertConvNchwToNhwc.cpp" "ConvertConvToChannelsLast.cpp" "ConvertLinalgMatmulToMmt.cpp" - "GeneralizeAndFuse.cpp" "GeneralizeConvolutions.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp deleted file mode 100644 index f22e55cf92ac..000000000000 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include -#include - -#include "iree/compiler/Preprocessing/Common/PassDetail.h" -#include "iree/compiler/Preprocessing/Common/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -namespace { - -// Converts linalg.matmul to an linalg.transpose + linalg.matmul. -// Such that matrix B layout changes to col major. -class LinalgMatmulOpToLinalgMmtPattern final - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, - PatternRewriter &rewriter) const override { - Location loc = matmulOp.getLoc(); - Value lhs = matmulOp.getDpsInputOperand(0)->get(); - Value rhs = matmulOp.getDpsInputOperand(1)->get(); - Value acc = matmulOp.getDpsInitOperand(0)->get(); - if (dyn_cast(rhs.getDefiningOp())) { - return failure(); - } - auto rhsType = rhs.getType().cast(); - auto rhsShape = rhsType.getShape(); - auto rhsElemType = rhsType.getElementType(); - SmallVector transposedRhsShape = {rhsShape[1], rhsShape[0]}; - - // GenericOp - int64_t nloops = rhsShape.size(); - AffineExpr mDim, nDim; - bindDims(getContext(), mDim, nDim); - auto inputMap = AffineMap::get(2, 0, {mDim, nDim}, getContext()); - auto packedMap = AffineMap::get(2, 0, {nDim, mDim}, getContext()); - SmallVector indexingMaps = {inputMap, packedMap}; - - Value transposedRhs = - rewriter.create(loc, transposedRhsShape, rhsElemType); - SmallVector loopAttributeTypes( - nloops, utils::IteratorType::parallel); - - Value packedRhs = - rewriter - .create( - loc, transposedRhs.getType(), - /*inputs=*/rhs, /*outputs=*/transposedRhs, indexingMaps, - loopAttributeTypes, - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); - }) - .getResult(0); - - // TransposeOp - Value initOp = rewriter.create(loc, rhsShape, rhsElemType); - SmallVector transposedPerm = {1, 0}; - Value transposePackedRhs = - rewriter - .create(loc, packedRhs, initOp, transposedPerm) - .getResults()[0]; - - // MatmulOp - Value packedMatmul = - rewriter - .create(loc, matmulOp.getResult(0).getType(), - ArrayRef{lhs, transposePackedRhs}, - ArrayRef{acc}) - .getResult(0); - rewriter.replaceOp(matmulOp, packedMatmul); - return success(); - } -}; - -struct ConvertLinalgMatmulToMmtPass - : public ConvertLinalgMatmulToMmtBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *context = &getContext(); - // Main pattern. - { - RewritePatternSet patterns(&getContext()); - patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } - } -}; -} // namespace - -std::unique_ptr createConvertLinalgMatmulToMmtPass() { - return std::make_unique(); -} - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp deleted file mode 100644 index 8a4e7bee31e3..000000000000 --- a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include -#include - -#include "iree/compiler/Preprocessing/Common/PassDetail.h" -#include "iree/compiler/Preprocessing/Common/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -namespace { - -//===----------------------------------------------------------------------===// -// Utility Functions -//===----------------------------------------------------------------------===// - -bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp) { - return linalg::isaContractionOpInterface(linalgOp) && - llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops()); -} - -//===----------------------------------------------------------------------===// -// Generalize and fusion patterns. -//===----------------------------------------------------------------------===// - -struct GeneralizeAndFusePass - : public GeneralizeAndFuseBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - template - class GeneralizeTargetNamedOpPattern final - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LinalgOpType linalgOp, - PatternRewriter &rewriter) const override { - // TODO: Check consumer is transposeOp. - // TODO: Generalize transpos - FailureOr genericOp = - linalg::generalizeNamedOp(rewriter, linalgOp); - if (failed(genericOp)) return failure(); - return success(); - } - }; - - class FuseMatmulAndTranspose final - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - // Inspo: - // https://github.com/llvm/llvm-project/blob/4f1c12425179608298dc39f5524ba2612609b5e4/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp - LogicalResult matchAndRewrite(linalg::GenericOp linalgOp, - PatternRewriter &rewriter) const override { - const unsigned rhsId = 1; - if (!isMatmulOrBatchMatmul(linalgOp)) return failure(); - Value rhs = linalgOp.getDpsInputOperand(rhsId)->get(); - auto transposeOp = dyn_cast(rhs.getDefiningOp()); - if (!transposeOp) return failure(); - auto perm = transposeOp.getPermutation(); - auto indexingMaps = linalgOp.getIndexingMaps(); - auto rhsMap = indexingMaps[rhsId].cast().getValue(); - int64_t rank = perm.size(); - if (rhsMap.getNumResults() != rank) return failure(); - SmallVector exprs; - for (auto dim_id : perm) { - exprs.push_back(rhsMap.getResult(dim_id)); - } - AffineMap transposedRhsMap = - AffineMap::get(rhsMap.getNumDims(), 0, exprs, getContext()); - - // TODO: Fold transposeOp as transposed indexing for matmulOp. - // Generate a map set. - auto lhsMap = indexingMaps[0].cast().getValue(); - auto accMap = indexingMaps[2].cast().getValue(); - SmallVector newIndexingMaps = {lhsMap, transposedRhsMap, - accMap}; - - // Generate new list of args. - Value newRhs = transposeOp.getDpsInputOperand(0)->get(); - Value lhs = linalgOp.getDpsInputOperand(0)->get(); - Value acc = linalgOp.getDpsInitOperand(0)->get(); - SmallVector inputs = {lhs, newRhs}; - - // Generate a new genericOp. - linalg::GenericOp genericOp = rewriter.create( - linalgOp.getLoc(), linalgOp.getResultTypes(), /*inputs*/ inputs, - /*outputs*/ acc, newIndexingMaps, linalgOp.getIteratorTypesArray()); - // Block consumerBlock = linalgOp->getRegion(0).front(); - // genericOp.getRegion().push_back(consumerBlock); - // llvm::outs()<<"new op - // regions:"<getNumRegions()<<"\n"; - // llvm::outs()<<"new op - // regions:"<getNumRegions()<<"\n"; - // llvm::outs()<<"new op - // blocks:"<getNumRegions()<<"\n"; - // llvm::outs()<<"old op - // blocks:"<getRegion(0), genericOp.getRegion(), - genericOp.getRegion().begin()); - rewriter.replaceOp(linalgOp, genericOp->getResults()); - return success(); - } - }; - - void runOnOperation() override { - MLIRContext *context = &getContext(); - // Main pattern. - // Generalize + Fuse pattern. - { - RewritePatternSet patterns(&getContext()); - patterns.insert, - FuseMatmulAndTranspose>(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } - } -}; -} // namespace - -std::unique_ptr createGeneralizeAndFusePass() { - return std::make_unique(); -} - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index 662bdf59d9cf..b6158e0caab2 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -24,12 +24,6 @@ std::unique_ptr createConvertConv2DToImg2ColPass(); std::unique_ptr> createConvertConvNchwToNhwcPass(); -// Pass to convert a linalg.matmul into linalg.transpose + linalg.matmul. -std::unique_ptr createConvertLinalgMatmulToMmtPass(); - -// Generalizes named op and try to fuse them -std::unique_ptr createGeneralizeAndFusePass(); - /// Moves the body of the entire function into a single dispatch. std::unique_ptr> createMakeSingleDispatchForFunctionPass(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index db0d89940ef7..a2fb9328965f 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -22,18 +22,6 @@ def ConvertConvNchwToNhwc : "mlir::iree_compiler::IREE::createConvertConvNchwToNhwcPass()"; } -def ConvertLinalgMatmulToMmt : - Pass<"iree-flow-convert-linalg-matmul-to-mmt", ""> { - let summary = "Convert linalg.matmul to linalg.transpose + linalg.matmul"; - let constructor = "mlir::iree_compiler::IREE::createConvertLinalgMatmulToMmtPass()"; -} - -def GeneralizeAndFuse : - Pass<"iree-flow-generalize-and-fuse", ""> { - let summary = "Generalizes named op and try to fuse them."; - let constructor = "mlir::iree_compiler::IREE::createGeneralizeAndFusePass()"; -} - def MakeSingleDispatchForFunction : Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> { let summary = "Convert entire function into a single dispatch"; From bf901a66b950fc5299ce8d39b678a71314b33508 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 25 Aug 2023 13:45:20 +0000 Subject: [PATCH 24/38] Remove ConvertLinalgMatmulToMmt completely --- compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel | 1 - compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 7c43380e1f66..f33d8712093b 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -33,7 +33,6 @@ iree_compiler_cc_library( "ConvertConv2DToImg2Col.cpp", "ConvertConvNchwToNhwc.cpp", "ConvertConvToChannelsLast.cpp", - "ConvertLinalgMatmulToMmt.cpp", "GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index a147974d907d..09e61373aed9 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -29,7 +29,6 @@ iree_cc_library( "ConvertConv2DToImg2Col.cpp" "ConvertConvNchwToNhwc.cpp" "ConvertConvToChannelsLast.cpp" - "ConvertLinalgMatmulToMmt.cpp" "GeneralizeConvolutions.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" From 579357b720612c2a1c69d2c28d9d62212f0bf954 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 24 Aug 2023 01:11:37 -0500 Subject: [PATCH 25/38] Add hip headers to build ROCm backend without the SDK. --- build_tools/python_deploy/build_linux_packages.sh | 2 ++ build_tools/python_deploy/build_windows_packages.ps1 | 2 ++ 2 files changed, 4 insertions(+) diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 3363d35b421c..c2739fc90585 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -157,10 +157,12 @@ function build_iree_runtime() { export IREE_RUNTIME_BUILD_TRACY=ON # We install the needed build deps below for the tools. export IREE_RUNTIME_BUILD_TRACY_TOOLS=ON + export IREE_EXTERNAL_HAL_DRIVERS="rocm" build_wheel runtime/ } function build_iree_compiler() { + export IREE_TARGET_BACKEND_ROCM=ON build_wheel compiler/ } diff --git a/build_tools/python_deploy/build_windows_packages.ps1 b/build_tools/python_deploy/build_windows_packages.ps1 index 43906f8800f7..a8fbd3a9aadf 100644 --- a/build_tools/python_deploy/build_windows_packages.ps1 +++ b/build_tools/python_deploy/build_windows_packages.ps1 @@ -67,11 +67,13 @@ function run() { function build_iree_runtime() { param($python_version) $env:IREE_HAL_DRIVER_VULKAN = "ON" + $env:IREE_EXTERNAL_HAL_DRIVERS = "rocm" & py -${python_version} -m pip wheel -v -w $output_dir $repo_root/runtime/ } function build_iree_compiler() { param($python_version) + $env:IREE_TARGET_BACKEND_ROCM= "ON" py -${python_version} -m pip wheel -v -w $output_dir $repo_root/compiler/ } From faaec8f3482b92240d0e900019a8a934e0f38560 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Thu, 14 Sep 2023 00:59:20 -0700 Subject: [PATCH 26/38] [experimental][ROCM] Stream Command Buffer --- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 4 +- experimental/rocm/CMakeLists.txt | 5 + experimental/rocm/api.h | 41 ++ experimental/rocm/dynamic_symbol_tables.h | 1 + .../rocm/registration/driver_module.c | 23 +- experimental/rocm/rocm_device.c | 92 ++- experimental/rocm/rocm_device.h | 1 + experimental/rocm/rocm_driver.c | 10 +- experimental/rocm/stream_command_buffer.c | 562 ++++++++++++++++++ experimental/rocm/stream_command_buffer.h | 49 ++ 10 files changed, 772 insertions(+), 16 deletions(-) create mode 100644 experimental/rocm/stream_command_buffer.c create mode 100644 experimental/rocm/stream_command_buffer.h diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index e529e8a1017a..c0bbef5614fd 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -501,7 +501,7 @@ static void addLowerAndOptimzeAddressComputation(OpPassManager &pm) { pm.addPass(createExtractAddressComputationGPUPass()); pm.addNestedPass(memref::createExpandOpsPass()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(memref::createExpandStridedMetadataPass()); + pm.addPass(createIREEExpandStridedMetadataPass()); // Hoist loop invariant variables to give decompose affine pass the right loop // dependencies. pm.addPass(createLoopInvariantCodeMotionPass()); @@ -575,7 +575,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &pm, bool useROCM) { pm.addNestedPass(memref::createExpandOpsPass()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(memref::createExpandStridedMetadataPass()); + pm.addPass(createIREEExpandStridedMetadataPass()); pm.addPass(createEmulateNarrowTypePass()); pm.addPass(createLowerAffinePass()); pm.addPass(createCanonicalizerPass()); diff --git a/experimental/rocm/CMakeLists.txt b/experimental/rocm/CMakeLists.txt index 7a48dabe8fda..9cbc6ee14633 100644 --- a/experimental/rocm/CMakeLists.txt +++ b/experimental/rocm/CMakeLists.txt @@ -61,6 +61,8 @@ iree_cc_library( "pipeline_layout.h" "status_util.c" "status_util.h" + "stream_command_buffer.c" + "stream_command_buffer.h" "tracing.c" "tracing.h" INCLUDES @@ -75,8 +77,11 @@ iree_cc_library( iree::base::internal::flatcc::parsing iree::base::internal::synchronization iree::hal + iree::hal::utils::collective_batch + iree::hal::utils::deferred_command_buffer iree::hal::utils::file_transfer iree::hal::utils::memory_file + iree::hal::utils::resource_set iree::hal::utils::semaphore_base iree::schemas::rocm_executable_def_c_fbs COPTS diff --git a/experimental/rocm/api.h b/experimental/rocm/api.h index 68fa1913bf2f..7949ac407afa 100644 --- a/experimental/rocm/api.h +++ b/experimental/rocm/api.h @@ -16,6 +16,46 @@ extern "C" { #endif // __cplusplus +//===----------------------------------------------------------------------===// +// iree_hal_rocm_device_t +//===----------------------------------------------------------------------===// + +// Defines how command buffers are recorded and executed. +typedef enum iree_hal_rocm_command_buffer_mode_e { + // Command buffers are recorded into ROCM null stream. + IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT = 0, + // Command buffers are directly issued against ROCM stream. + IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM = 1, +} iree_hal_rocm_command_buffer_mode_t; + +// Parameters configuring an iree_hal_rocm_device_t. +// Must be initialized with iree_hal_rocm_device_params_initialize prior to use. +typedef struct iree_hal_rocm_device_params_t { + + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; + + // Specifies how command buffers are recorded and executed. + iree_hal_rocm_command_buffer_mode_t command_buffer_mode; + + // Enables tracing of command buffers when IREE tracing is enabled. + // May take advantage of additional extensions for more accurate timing or + // hardware-specific performance counters. + // + // NOTE: tracing has a non-trivial overhead and will skew the timing of + // submissions and introduce false barriers between dispatches. Use this to + // identify slow dispatches and refine from there; be wary of whole-program + // tracing with this enabled. + bool stream_tracing; + +} iree_hal_rocm_device_params_t; + +// Initializes |out_params| to default values. +IREE_API_EXPORT void iree_hal_rocm_device_params_initialize( + iree_hal_rocm_device_params_t* out_params); + //===----------------------------------------------------------------------===// // iree_hal_rocm_driver_t //===----------------------------------------------------------------------===// @@ -35,6 +75,7 @@ IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize( // |out_driver| must be released by the caller (see |iree_hal_driver_release|). IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t *options, iree_allocator_t host_allocator, iree_hal_driver_t **out_driver); diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h index 785f0edc9ea3..b0ee67dba5a8 100644 --- a/experimental/rocm/dynamic_symbol_tables.h +++ b/experimental/rocm/dynamic_symbol_tables.h @@ -31,6 +31,7 @@ RC_PFN_DECL(hipMemsetD32Async, void *, int, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD16Async, void *, short, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD8Async, void *, char, size_t, hipStream_t) RC_PFN_DECL(hipMemcpy, void *, const void *, size_t, hipMemcpyKind) +RC_PFN_DECL(hipMemcpyHtoDAsync, hipDeviceptr_t, void *, size_t, hipStream_t) RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind, hipStream_t) RC_PFN_DECL(hipMalloc, void **, size_t) diff --git a/experimental/rocm/registration/driver_module.c b/experimental/rocm/registration/driver_module.c index fcdadfe3c112..f1e180a91803 100644 --- a/experimental/rocm/registration/driver_module.c +++ b/experimental/rocm/registration/driver_module.c @@ -11,6 +11,19 @@ #include "experimental/rocm/api.h" #include "iree/base/api.h" +#include "iree/base/internal/flags.h" + +// Force using ROCM streams until we support command buffer caching to avoid the +// overhead of graph creation. +IREE_FLAG( + bool, rocm_use_streams, true, + "Use ROCM streams for executing command buffers (instead of graphs)."); + +IREE_FLAG( + bool, rocm_tracing, true, + "Enables tracing of stream events when Tracy instrumentation is enabled.\n" + "Severely impacts benchmark timings and should only be used when\n" + "analyzing dispatch timings."); static iree_status_t iree_hal_rocm_driver_factory_enumerate( void *self, iree_host_size_t *out_driver_info_count, @@ -36,10 +49,18 @@ static iree_status_t iree_hal_rocm_driver_factory_try_create( (int)driver_name.size, driver_name.data); } IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_rocm_device_params_t default_params; + iree_hal_rocm_device_params_initialize(&default_params); + if (FLAG_rocm_use_streams) { + default_params.command_buffer_mode = + IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM; + } + default_params.stream_tracing = FLAG_rocm_tracing; iree_hal_rocm_driver_options_t driver_options; iree_hal_rocm_driver_options_initialize(&driver_options); iree_status_t status = iree_hal_rocm_driver_create( - driver_name, &driver_options, host_allocator, out_driver); + driver_name, &default_params, &driver_options, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; } diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c index b76f165efa5c..9e508d6b8907 100644 --- a/experimental/rocm/rocm_device.c +++ b/experimental/rocm/rocm_device.c @@ -19,8 +19,10 @@ #include "experimental/rocm/rocm_allocator.h" #include "experimental/rocm/rocm_event.h" #include "experimental/rocm/status_util.h" +#include "experimental/rocm//stream_command_buffer.h" #include "experimental/rocm/tracing.h" #include "iree/base/internal/arena.h" +#include "iree/hal/utils/deferred_command_buffer.h" #include "iree/hal/utils/file_transfer.h" #include "iree/hal/utils/memory_file.h" @@ -40,6 +42,9 @@ typedef struct iree_hal_rocm_device_t { // to ensure the symbols remains valid. iree_hal_driver_t* driver; + // Parameters used to control device behavior. + iree_hal_rocm_device_params_t params; + hipDevice_t device; // TODO: support multiple streams. @@ -50,6 +55,10 @@ typedef struct iree_hal_rocm_device_t { // Optional provider used for creating/configuring collective channels. iree_hal_channel_provider_t* channel_provider; + + // Cache of the direct stream command buffer initialized when in stream mode. + // TODO: have one cached per stream once there are multiple streams. + iree_hal_command_buffer_t* stream_command_buffer; } iree_hal_rocm_device_t; static const iree_hal_device_vtable_t iree_hal_rocm_device_vtable; @@ -60,11 +69,21 @@ static iree_hal_rocm_device_t* iree_hal_rocm_device_cast( return (iree_hal_rocm_device_t*)base_value; } +IREE_API_EXPORT void iree_hal_rocm_device_params_initialize( + iree_hal_rocm_device_params_t* out_params) { + memset(out_params, 0, sizeof(*out_params)); + out_params->arena_block_size = 32*1024; + out_params->command_buffer_mode = IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT; + out_params->stream_tracing = false; +} + static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_command_buffer_release(device->stream_command_buffer); + // There should be no more buffers live that use the allocator. iree_hal_allocator_release(device->device_allocator); @@ -75,6 +94,8 @@ static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { ROCM_IGNORE_ERROR(device->context_wrapper.syms, hipStreamDestroy(device->stream)); + iree_arena_block_pool_deinitialize(&device->block_pool); + // Finally, destroy the device. iree_hal_driver_release(device->driver); @@ -85,9 +106,9 @@ static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { static iree_status_t iree_hal_rocm_device_create_internal( iree_hal_driver_t* driver, iree_string_view_t identifier, - hipDevice_t rocm_device, hipStream_t stream, hipCtx_t context, - iree_hal_rocm_dynamic_symbols_t* syms, iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { + const iree_hal_rocm_device_params_t* params, hipDevice_t rocm_device, + hipStream_t stream, hipCtx_t context, iree_hal_rocm_dynamic_symbols_t* syms, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { iree_hal_rocm_device_t* device = NULL; iree_host_size_t total_size = sizeof(*device) + identifier.size; IREE_RETURN_IF_ERROR( @@ -99,20 +120,36 @@ static iree_status_t iree_hal_rocm_device_create_internal( uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device); buffer_ptr += iree_string_view_append_to_buffer( identifier, &device->identifier, (char*)buffer_ptr); + device->params = *params; device->device = rocm_device; device->stream = stream; device->context_wrapper.rocm_context = context; device->context_wrapper.rocm_device = rocm_device; device->context_wrapper.host_allocator = host_allocator; + iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, + &device->block_pool); device->context_wrapper.syms = syms; // Enable tracing for the (currently only) stream - no-op if disabled. - iree_status_t status = iree_hal_rocm_tracing_context_allocate( + iree_status_t status = iree_ok_status(); + if (device->params.stream_tracing) { + status = iree_hal_rocm_tracing_context_allocate( &device->context_wrapper, device->identifier, stream, &device->block_pool, host_allocator, &device->tracing_context); + } if (iree_status_is_ok(status)) { status = iree_hal_rocm_allocator_create(&device->context_wrapper, &device->device_allocator); } + if (iree_status_is_ok(status) && + params->command_buffer_mode == IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM) { + status = iree_hal_rocm_stream_command_buffer_create( + (iree_hal_device_t*)device, &device->context_wrapper, + device->tracing_context, + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION | + IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED, + IREE_HAL_COMMAND_CATEGORY_ANY, /*binding_capacity=*/0, device->stream, + &device->block_pool, &device->stream_command_buffer); + } if (iree_status_is_ok(status)) { *out_device = (iree_hal_device_t*)device; } else { @@ -123,10 +160,12 @@ static iree_status_t iree_hal_rocm_device_create_internal( iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* params, iree_hal_rocm_dynamic_symbols_t* syms, hipDevice_t device, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(params); IREE_TRACE_ZONE_BEGIN(z0); hipCtx_t context; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -140,8 +179,8 @@ iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, syms, hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); if (iree_status_is_ok(status)) { - status = iree_hal_rocm_device_create_internal(driver, identifier, device, - stream, context, syms, + status = iree_hal_rocm_device_create_internal(driver, identifier, params, + device, stream, context, syms, host_allocator, out_device); } if (!iree_status_is_ok(status)) { @@ -228,10 +267,21 @@ static iree_status_t iree_hal_rocm_device_create_command_buffer( iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); - return iree_hal_rocm_direct_command_buffer_create( - base_device, &device->context_wrapper, device->tracing_context, mode, - command_categories, queue_affinity, binding_capacity, &device->block_pool, - out_command_buffer); + switch (device->params.command_buffer_mode) { + case IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT: + return iree_hal_rocm_direct_command_buffer_create( + base_device, &device->context_wrapper, device->tracing_context, mode, + command_categories, queue_affinity, binding_capacity, &device->block_pool, + out_command_buffer); + case IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM: + return iree_hal_deferred_command_buffer_create( + base_device, mode, command_categories, binding_capacity, + &device->block_pool, iree_hal_device_host_allocator(base_device), + out_command_buffer); + default: + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid command buffer mode"); + } } static iree_status_t iree_hal_rocm_device_create_descriptor_set_layout( @@ -383,8 +433,28 @@ static iree_status_t iree_hal_rocm_device_queue_execute( // synchronizes after every submit. // TODO(raikonenfnu): currently run on default/null stream, when cmd buffer // stream work with device->stream, we'll change + for (iree_host_size_t i = 0; i < command_buffer_count; i++) { + iree_hal_command_buffer_t* command_buffer = command_buffers[i]; + if (iree_hal_rocm_stream_command_buffer_isa(command_buffer)) { + // Nothing to do for an inline command buffer; all the work has already + // been submitted. When we support semaphores we'll still need to signal + // their completion but do not have to worry about any waits: if there + // were waits we wouldn't have been able to execute inline! + } else if (iree_hal_rocm_direct_command_buffer_isa(command_buffer)) { + IREE_TRACE_ZONE_BEGIN_NAMED(z0, "hipStreamSynchronize"); + ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(0), + "hipStreamSynchronize"); + iree_hal_rocm_tracing_context_collect(device->tracing_context); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } else { + IREE_RETURN_IF_ERROR(iree_hal_deferred_command_buffer_apply( + command_buffers[i], device->stream_command_buffer, + iree_hal_buffer_binding_table_empty())); + } + } IREE_TRACE_ZONE_BEGIN_NAMED(z0, "hipStreamSynchronize"); - ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(0), + ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(device->stream), "hipStreamSynchronize"); iree_hal_rocm_tracing_context_collect(device->tracing_context); IREE_TRACE_ZONE_END(z0); diff --git a/experimental/rocm/rocm_device.h b/experimental/rocm/rocm_device.h index 083f4c7cddb6..7abd4e67ce36 100644 --- a/experimental/rocm/rocm_device.h +++ b/experimental/rocm/rocm_device.h @@ -19,6 +19,7 @@ extern "C" { // Creates a device that owns and manages its own hipContext. iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* params, iree_hal_rocm_dynamic_symbols_t* syms, hipDevice_t device, iree_allocator_t host_allocator, diff --git a/experimental/rocm/rocm_driver.c b/experimental/rocm/rocm_driver.c index bcec506e2f86..5b67fdc2f3a0 100644 --- a/experimental/rocm/rocm_driver.c +++ b/experimental/rocm/rocm_driver.c @@ -21,6 +21,7 @@ typedef struct iree_hal_rocm_driver_t { // We allow overriding so that multiple ROCM versions can be exposed in the // same process. iree_string_view_t identifier; + iree_hal_rocm_device_params_t default_params; int default_device_index; // ROCM symbols. iree_hal_rocm_dynamic_symbols_t syms; @@ -49,6 +50,7 @@ IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize( static iree_status_t iree_hal_rocm_driver_create_internal( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { iree_hal_rocm_driver_t* driver = NULL; @@ -60,6 +62,8 @@ static iree_status_t iree_hal_rocm_driver_create_internal( iree_string_view_append_to_buffer( identifier, &driver->identifier, (char*)driver + total_size - identifier.size); + memcpy(&driver->default_params, default_params, + sizeof(driver->default_params)); driver->default_device_index = options->default_device_index; iree_status_t status = iree_hal_rocm_dynamic_symbols_initialize(host_allocator, &driver->syms); @@ -84,14 +88,16 @@ static void iree_hal_rocm_driver_destroy(iree_hal_driver_t* base_driver) { IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(default_params); IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_driver); IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_rocm_driver_create_internal( - identifier, options, host_allocator, out_driver); + identifier, default_params, options, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; @@ -286,7 +292,7 @@ static iree_status_t iree_hal_rocm_driver_create_device_by_id( // Attempt to create the device. iree_status_t status = - iree_hal_rocm_device_create(base_driver, device_name, &driver->syms, + iree_hal_rocm_device_create(base_driver, device_name, &driver->default_params, &driver->syms, device, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); diff --git a/experimental/rocm/stream_command_buffer.c b/experimental/rocm/stream_command_buffer.c new file mode 100644 index 000000000000..f6d98df63275 --- /dev/null +++ b/experimental/rocm/stream_command_buffer.c @@ -0,0 +1,562 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/rocm/stream_command_buffer.h" + +#include "experimental/rocm/rocm_buffer.h" +#include "experimental/rocm/rocm_event.h" +#include "experimental/rocm/native_executable.h" +#include "experimental/rocm/pipeline_layout.h" +#include "experimental/rocm/status_util.h" +#include "iree/hal/utils/collective_batch.h" +#include "iree/hal/utils/resource_set.h" + +#define IREE_HAL_ROCM_MAX_BINDING_COUNT 64 +// Kernel arguments contains binding and push constants. +#define IREE_HAL_ROCM_MAX_KERNEL_ARG 128 + +typedef struct { + iree_hal_command_buffer_t base; + iree_hal_rocm_context_wrapper_t* context; + iree_hal_rocm_tracing_context_t* tracing_context; + hipStream_t stream; + + // Maintains a reference to all resources used within the command buffer. + // Reset on each begin. + iree_hal_resource_set_t* resource_set; + + // Staging arena used for host->device transfers. + // Used for when we need ROCM to be able to reference memory as it performs + // asynchronous operations. + iree_arena_allocator_t arena; + + // Iteratively constructed batch of collective operations. + iree_hal_collective_batch_t collective_batch; + + int32_t push_constant[IREE_HAL_ROCM_MAX_PUSH_CONSTANT_COUNT]; + + // Keep track of the current set of kernel arguments. + void* current_descriptor[IREE_HAL_ROCM_MAX_KERNEL_ARG]; + hipDeviceptr_t* device_ptrs[IREE_HAL_ROCM_MAX_KERNEL_ARG]; +} iree_hal_rocm_stream_command_buffer_t; + +static const iree_hal_command_buffer_vtable_t + iree_hal_rocm_stream_command_buffer_vtable; + +static iree_hal_rocm_stream_command_buffer_t* +iree_hal_rocm_stream_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_stream_command_buffer_vtable); + return (iree_hal_rocm_stream_command_buffer_t*)base_value; +} + +iree_status_t iree_hal_rocm_stream_command_buffer_create( + iree_hal_device_t* device, iree_hal_rocm_context_wrapper_t* context, + iree_hal_rocm_tracing_context_t* tracing_context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, hipStream_t stream, + iree_arena_block_pool_t* block_pool, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + if (binding_capacity > 0) { + // TODO(#10144): support indirect command buffers with binding tables. + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_rocm_stream_command_buffer_t* command_buffer = NULL; + iree_status_t status = + iree_allocator_malloc(context->host_allocator, sizeof(*command_buffer), + (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_command_buffer_initialize( + device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, + binding_capacity, &iree_hal_rocm_stream_command_buffer_vtable, + &command_buffer->base); + command_buffer->context = context; + command_buffer->tracing_context = tracing_context; + command_buffer->stream = stream; + iree_arena_initialize(block_pool, &command_buffer->arena); + for (size_t i = 0; i < IREE_HAL_ROCM_MAX_KERNEL_ARG; i++) { + command_buffer->current_descriptor[i] = &command_buffer->device_ptrs[i]; + } + + status = iree_hal_resource_set_allocate(block_pool, + &command_buffer->resource_set); + } + if (iree_status_is_ok(status)) { + iree_hal_collective_batch_initialize(&command_buffer->arena, + command_buffer->resource_set, + &command_buffer->collective_batch); + } + + *out_command_buffer = &command_buffer->base; + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_rocm_stream_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch); + iree_hal_resource_set_free(command_buffer->resource_set); + iree_arena_deinitialize(&command_buffer->arena); + iree_allocator_free(command_buffer->context->host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +bool iree_hal_rocm_stream_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_resource_is(&command_buffer->resource, + &iree_hal_rocm_stream_command_buffer_vtable); +} + +// Flushes any pending batched collective operations. +// Must be called before any other non-collective nodes are added to the graph +// or a barrier is encountered. +static iree_status_t iree_hal_rocm_stream_command_buffer_flush_collectives( + iree_hal_rocm_stream_command_buffer_t* command_buffer) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Collectives not implemented on ROCM"); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, + /*file_name=*/NULL, 0, + /*line=*/0, /*func_name=*/NULL, 0, "iree_hal_rocm_stream_command_buffer", + strlen("iree_hal_rocm_stream_command_buffer")); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Reset the arena as there should be nothing using it now that we've + // dispatched all our operations inline. + // NOTE: the resource set may contain resources we need to drop as we don't + // need to keep them live any longer than it takes to schedule the + // operations. In a real command buffer we would be this stream command + // buffer is strictly used to perform inline execution/replay of + // deferred command buffers that are retaining the resources already. + // NOTE: reseting the arena invalidates the collective batch. + iree_arena_reset(&command_buffer->arena); + iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch); + iree_hal_resource_set_free(command_buffer->resource_set); + IREE_RETURN_IF_ERROR(iree_hal_resource_set_allocate( + command_buffer->arena.block_pool, &command_buffer->resource_set)); + iree_hal_collective_batch_initialize(&command_buffer->arena, + command_buffer->resource_set, + &command_buffer->collective_batch); + + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); + + return iree_ok_status(); +} + +static void iree_hal_rocm_stream_command_buffer_begin_debug_group( + iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label, + iree_hal_label_color_t label_color, + const iree_hal_label_location_t* location) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, + location ? location->file.data : NULL, location ? location->file.size : 0, + location ? location->line : 0, /*func_name=*/NULL, 0, label.data, + label.size); + + // TODO: pass along to CUPTI if available. +} + +static void iree_hal_rocm_stream_command_buffer_end_debug_group( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + // TODO: pass along to CUPTI if available. + + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_hal_execution_barrier_flags_t flags, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { + // We could mark the memory as invalidated so that if managed ROCM does not + // try to copy it back to the host. + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + hipDeviceptr_t dst = + (hipDeviceptr_t)((uintptr_t)target_device_buffer + target_offset); + + size_t num_elements = length / pattern_length; + switch (pattern_length) { + case 4: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD32Async(dst, *(const uint32_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD32Async"); + break; + } + case 2: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD16Async(dst, *(const uint16_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD16Async"); + break; + } + case 1: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD8Async(dst, *(const uint8_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD8Async"); + break; + } + default: + return iree_make_status(IREE_STATUS_INTERNAL, + "unsupported fill pattern length"); + } + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Allocate scratch space in the arena for the data and copy it in. + // The update buffer API requires that the command buffer capture the host + // memory at the time the method is called in case the caller wants to reuse + // the memory. Because ROCM memcpys are async if we didn't copy it's possible + // for the reused memory to change before the stream reaches the copy + // operation and get the wrong data. + const uint8_t* src = (const uint8_t*)source_buffer + source_offset; + if (command_buffer->arena.block_pool) { + uint8_t* storage = NULL; + IREE_RETURN_IF_ERROR( + iree_arena_allocate(&command_buffer->arena, length, (void**)&storage)); + memcpy(storage, src, length); + src = storage; + } + + // Issue the copy using the scratch memory as the source. + hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + hipDeviceptr_t dst = (hipDeviceptr_t)((uintptr_t)target_device_buffer + + iree_hal_buffer_byte_offset(target_buffer) + target_offset); + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemcpyHtoDAsync(dst, (void*)src, length, command_buffer->stream), + "hipMemcpyHtoDAsync"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + hipDeviceptr_t source_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(source_buffer)); + source_offset += iree_hal_buffer_byte_offset(source_buffer); + hipDeviceptr_t dst = (hipDeviceptr_t)((uintptr_t)target_device_buffer + target_offset); + hipDeviceptr_t src = (hipDeviceptr_t)((uintptr_t)source_device_buffer + source_offset); + ROCM_RETURN_IF_ERROR(command_buffer->context->syms, + hipMemcpyAsync(dst, src, length, hipMemcpyDeviceToDevice, command_buffer->stream), + "hipMemcpyAsync"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_collective( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel, + iree_hal_collective_op_t op, uint32_t param, + iree_hal_buffer_binding_t send_binding, + iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + return iree_hal_collective_batch_append(&command_buffer->collective_batch, + channel, op, param, send_binding, + recv_binding, element_count); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + + iree_host_size_t constant_base_index = offset / sizeof(int32_t); + for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) { + command_buffer->push_constant[i + constant_base_index] = + ((uint32_t*)values)[i]; + } + + return iree_ok_status(); +} + +// Tie together the binding index and its index in |bindings| array. +typedef struct { + uint32_t index; + uint32_t binding; +} iree_hal_rocm_binding_mapping_t; + +// Helper to sort the binding based on their binding index. +static int compare_binding_index(const void* a, const void* b) { + const iree_hal_rocm_binding_mapping_t buffer_a = + *(const iree_hal_rocm_binding_mapping_t*)a; + const iree_hal_rocm_binding_mapping_t buffer_b = + *(const iree_hal_rocm_binding_mapping_t*)b; + return buffer_a.binding < buffer_b.binding ? -1 : 1; +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + + iree_host_size_t base_binding = + iree_hal_rocm_base_binding_index(pipeline_layout, set); + + // Convention with the compiler side. We map bindings to kernel argument. + // We compact the bindings to get a dense set of arguments and keep them order + // based on the binding index. + // Sort the binding based on the binding index and map the array index to the + // argument index. + iree_hal_rocm_binding_mapping_t binding_used[IREE_HAL_ROCM_MAX_BINDING_COUNT]; + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_rocm_binding_mapping_t buffer = {i, bindings[i].binding}; + binding_used[i] = buffer; + } + // TODO: remove this sort - it's thankfully small (1-8 on average) but we + // should be able to avoid it like we do on the CPU side with a bitmap. + qsort(binding_used, binding_count, sizeof(iree_hal_rocm_binding_mapping_t), + compare_binding_index); + assert(binding_count < IREE_HAL_ROCM_MAX_BINDING_COUNT && + "binding count larger than the max expected."); + + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index]; + hipDeviceptr_t device_ptr = + binding.buffer + ? (hipDeviceptr_t)((uintptr_t)iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(binding.buffer)) + + iree_hal_buffer_byte_offset(binding.buffer) + binding.offset) + : 0; + *((hipDeviceptr_t*)command_buffer->current_descriptor[i + base_binding]) = + device_ptr; + } + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Lookup kernel parameters used for side-channeling additional launch + // information from the compiler. + iree_hal_rocm_kernel_params_t kernel_params; + IREE_RETURN_IF_ERROR( + iree_hal_rocm_native_executable_entry_point_kernel_params( + executable, entry_point, &kernel_params)); + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, kernel_params.function_name.data, + kernel_params.function_name.size, + /*line=*/0, /*func_name=*/NULL, 0, kernel_params.function_name.data, + kernel_params.function_name.size); + + // Patch the push constants in the kernel arguments. + iree_host_size_t num_constants = + iree_hal_rocm_pipeline_layout_num_constants(kernel_params.layout); + iree_host_size_t constant_base_index = + iree_hal_rocm_push_constant_index(kernel_params.layout); + for (iree_host_size_t i = 0; i < num_constants; i++) { + *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) = + command_buffer->push_constant[i]; + } + + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipModuleLaunchKernel(kernel_params.function, workgroup_x, workgroup_y, + workgroup_z, kernel_params.block_size[0], + kernel_params.block_size[1], kernel_params.block_size[2], + kernel_params.shared_memory_size, command_buffer->stream, + command_buffer->current_descriptor, NULL), + "hipModuleLaunchKernel"); + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need rocm implementation of dispatch indirect"); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_execute_commands( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_command_buffer_t* base_commands, + iree_hal_buffer_binding_table_t binding_table) { + // TODO(#10144): support indirect command buffers with deferred command + // buffers or graphs. We likely just want to switch to graphs. + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); +} + +static const iree_hal_command_buffer_vtable_t + iree_hal_rocm_stream_command_buffer_vtable = { + .destroy = iree_hal_rocm_stream_command_buffer_destroy, + .begin = iree_hal_rocm_stream_command_buffer_begin, + .end = iree_hal_rocm_stream_command_buffer_end, + .begin_debug_group = + iree_hal_rocm_stream_command_buffer_begin_debug_group, + .end_debug_group = iree_hal_rocm_stream_command_buffer_end_debug_group, + .execution_barrier = + iree_hal_rocm_stream_command_buffer_execution_barrier, + .signal_event = iree_hal_rocm_stream_command_buffer_signal_event, + .reset_event = iree_hal_rocm_stream_command_buffer_reset_event, + .wait_events = iree_hal_rocm_stream_command_buffer_wait_events, + .discard_buffer = iree_hal_rocm_stream_command_buffer_discard_buffer, + .fill_buffer = iree_hal_rocm_stream_command_buffer_fill_buffer, + .update_buffer = iree_hal_rocm_stream_command_buffer_update_buffer, + .copy_buffer = iree_hal_rocm_stream_command_buffer_copy_buffer, + .collective = iree_hal_rocm_stream_command_buffer_collective, + .push_constants = iree_hal_rocm_stream_command_buffer_push_constants, + .push_descriptor_set = + iree_hal_rocm_stream_command_buffer_push_descriptor_set, + .dispatch = iree_hal_rocm_stream_command_buffer_dispatch, + .dispatch_indirect = + iree_hal_rocm_stream_command_buffer_dispatch_indirect, + .execute_commands = + iree_hal_rocm_stream_command_buffer_execute_commands, +}; diff --git a/experimental/rocm/stream_command_buffer.h b/experimental/rocm/stream_command_buffer.h new file mode 100644 index 000000000000..691fa63809ff --- /dev/null +++ b/experimental/rocm/stream_command_buffer.h @@ -0,0 +1,49 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_ +#define IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_ + +#include "iree/base/internal/arena.h" +#include "iree/hal/api.h" +#include "experimental/rocm/context_wrapper.h" +#include "experimental/rocm/rocm_headers.h" +#include "experimental/rocm/dynamic_symbols.h" +#include "experimental/rocm/tracing.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a ROCM stream command buffer that immediately issues commands against +// the given |stream|. Access to |stream| must be synchronized by the user. +// +// If |block_pool| is non-NULL then the stream command buffer will retain copies +// of input data until reset. If NULL then the caller must ensure the lifetime +// of input data outlives the command buffer. +// +// This command buffer is used to both replay deferred command buffers and +// perform inline execution. When replaying the scratch data required for things +// like buffer updates is retained by the source deferred command buffer and as +// such the |block_pool| and can be NULL to avoid a double copy. +iree_status_t iree_hal_rocm_stream_command_buffer_create( + iree_hal_device_t* device, iree_hal_rocm_context_wrapper_t* context, + iree_hal_rocm_tracing_context_t* tracing_context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, hipStream_t stream, + iree_arena_block_pool_t* block_pool, + iree_hal_command_buffer_t** out_command_buffer); + +// Returns true if |command_buffer| is a ROCM stream-based command buffer. +bool iree_hal_rocm_stream_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_ From 3422cbae58a6bec832794fda8181460d6f6fe740 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Mon, 18 Sep 2023 02:38:26 -0700 Subject: [PATCH 27/38] [ROCM] Add supports_concurrent_managed_access --- experimental/rocm/dynamic_symbol_tables.h | 1 + experimental/rocm/rocm_allocator.c | 120 ++++++++++++++++++---- experimental/rocm/rocm_allocator.h | 2 + experimental/rocm/rocm_device.c | 1 + 4 files changed, 106 insertions(+), 18 deletions(-) diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h index b0ee67dba5a8..4214c2b92775 100644 --- a/experimental/rocm/dynamic_symbol_tables.h +++ b/experimental/rocm/dynamic_symbol_tables.h @@ -36,6 +36,7 @@ RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind, hipStream_t) RC_PFN_DECL(hipMalloc, void **, size_t) RC_PFN_DECL(hipMallocManaged, hipDeviceptr_t *, size_t, unsigned int) +RC_PFN_DECL(hipMemPrefetchAsync, const void *, size_t, int, hipStream_t) RC_PFN_DECL(hipFree, void *) RC_PFN_DECL(hipHostFree, void *) RC_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int) diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c index 3c63c71ec1bd..84dfb32f81c8 100644 --- a/experimental/rocm/rocm_allocator.c +++ b/experimental/rocm/rocm_allocator.c @@ -15,8 +15,11 @@ typedef struct iree_hal_rocm_allocator_t { iree_hal_resource_t resource; - iree_hal_device_t* base_device; iree_hal_rocm_context_wrapper_t* context; + hipDevice_t device; + hipStream_t stream; + + bool supports_concurrent_managed_access; IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;) } iree_hal_rocm_allocator_t; @@ -30,10 +33,30 @@ static iree_hal_rocm_allocator_t* iree_hal_rocm_allocator_cast( } iree_status_t iree_hal_rocm_allocator_create( - iree_hal_rocm_context_wrapper_t* context, + iree_hal_rocm_context_wrapper_t* context, hipDevice_t device, hipStream_t stream, iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(context); IREE_TRACE_ZONE_BEGIN(z0); + + // To support device-local + host-visible memory we need concurrent managed + // access indicating that the host and devices can concurrently access the + // device memory. If we don't have this feature then we fall back to forcing + // all device-local + host-visible memory into host-local + device-visible + // page-locked memory. The compiler tries to avoid this for high-traffic + // buffers except for readback staging buffers. + int supports_concurrent_managed_access = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, ROCM_RESULT_TO_STATUS( + context->syms, + hipDeviceGetAttribute( + &supports_concurrent_managed_access, + hipDeviceAttributeConcurrentManagedAccess, device), + "hipDeviceGetAttribute")); + IREE_TRACE_ZONE_APPEND_TEXT( + z0, supports_concurrent_managed_access + ? "has CONCURRENT_MANAGED_ACCESS" + : "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on " + "device-local + host-visible memory)"); iree_hal_rocm_allocator_t* allocator = NULL; iree_status_t status = iree_allocator_malloc( context->host_allocator, sizeof(*allocator), (void**)&allocator); @@ -41,6 +64,9 @@ iree_status_t iree_hal_rocm_allocator_create( iree_hal_resource_initialize(&iree_hal_rocm_allocator_vtable, &allocator->resource); allocator->context = context; + allocator->device = device; + allocator->stream = stream; + allocator->supports_concurrent_managed_access = supports_concurrent_managed_access !=0; *out_allocator = (iree_hal_allocator_t*)allocator; } @@ -87,24 +113,31 @@ static iree_status_t iree_hal_rocm_allocator_query_memory_heaps( iree_host_size_t capacity, iree_hal_allocator_memory_heap_t* IREE_RESTRICT heaps, iree_host_size_t* IREE_RESTRICT out_count) { - const iree_host_size_t count = 3; + iree_hal_rocm_allocator_t* allocator = + iree_hal_rocm_allocator_cast(base_allocator); + + // TODO(benvanik): check CU_DEVICE_ATTRIBUTE_INTEGRATED and return a unified + // set of heaps (likely still a cached and uncached, at minimum). + iree_host_size_t count = 3; + if (allocator->supports_concurrent_managed_access) { + ++count; // device-local | host-visible + } if (out_count) *out_count = count; if (capacity < count) { // NOTE: lightweight as this is hit in normal pre-sizing usage. return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); } - // NOTE: this is all a guess - someone who is familiar with rocm will want - // to refine this further. - // Don't think there's a query for these. // Max allocation size may be much smaller in certain memory types such as // page-locked memory and it'd be good to enforce that. const iree_device_size_t max_allocation_size = ~(iree_device_size_t)0; const iree_device_size_t min_alignment = 64; + int i = 0; + // Device-local memory (dispatch resources): - heaps[0] = (iree_hal_allocator_memory_heap_t){ + heaps[i++] = (iree_hal_allocator_memory_heap_t){ .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH, @@ -112,27 +145,46 @@ static iree_status_t iree_hal_rocm_allocator_query_memory_heaps( .min_alignment = min_alignment, }; + if (allocator->supports_concurrent_managed_access) { + // Device-local managed memory with host mapping support: + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + } + // Write-combined page-locked host-local memory (upload): - heaps[1] = (iree_hal_allocator_memory_heap_t){ - .type = - IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_COHERENT, - .allowed_usage = - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, .max_allocation_size = max_allocation_size, .min_alignment = min_alignment, }; // Cached page-locked host-local memory (download): - heaps[2] = (iree_hal_allocator_memory_heap_t){ - .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_COHERENT | IREE_HAL_MEMORY_TYPE_HOST_CACHED, - .allowed_usage = - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, .max_allocation_size = max_allocation_size, .min_alignment = min_alignment, }; + IREE_ASSERT(i == count); return iree_ok_status(); } @@ -141,22 +193,46 @@ iree_hal_rocm_allocator_query_buffer_compatibility( iree_hal_allocator_t* IREE_RESTRICT base_allocator, iree_hal_buffer_params_t* IREE_RESTRICT params, iree_device_size_t* IREE_RESTRICT allocation_size) { + iree_hal_rocm_allocator_t* allocator = + iree_hal_rocm_allocator_cast(base_allocator); + // All buffers can be allocated on the heap. iree_hal_buffer_compatibility_t compatibility = IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE; - if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { - compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + // Buffers are importable in ROCM under most cases, though performance may + // vary wildly. We don't fully verify that the buffer parameters are + // self-consistent and just look at whether we can get a device pointer. + if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE; } // Buffers can only be used on the queue if they are device visible. if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) { compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH; } } + // If concurrent managed access is not supported then make device-local + + // host-visible allocations fall back to host-local + device-visible + // page-locked memory. This will be significantly slower for the device to + // access but the compiler only uses this type for readback staging buffers + // and it's better to function than function fast. + if (!allocator->supports_concurrent_managed_access && + iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_LOW_PERFORMANCE; + params->type &= ~(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE); + params->type |= + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE; + } + // We are now optimal. params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL; @@ -209,6 +285,14 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer( status = ROCM_RESULT_TO_STATUS( allocator->context->syms, hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal)); + if (iree_status_is_ok(status) && + allocator->supports_concurrent_managed_access) { + // Prefetch the buffer on the GPU device. + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemPrefetchAsync(device_ptr, allocation_size, allocator->device, + allocator->stream)); + } host_ptr = (void*)device_ptr; } else { // Device only. diff --git a/experimental/rocm/rocm_allocator.h b/experimental/rocm/rocm_allocator.h index a2a89eab2cdd..c735e830b013 100644 --- a/experimental/rocm/rocm_allocator.h +++ b/experimental/rocm/rocm_allocator.h @@ -19,6 +19,8 @@ extern "C" { // Create a ROCM allocator. iree_status_t iree_hal_rocm_allocator_create( iree_hal_rocm_context_wrapper_t* context, + hipDevice_t device, + hipStream_t stream, iree_hal_allocator_t** out_allocator); #ifdef __cplusplus diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c index 9e508d6b8907..24da6bd7da66 100644 --- a/experimental/rocm/rocm_device.c +++ b/experimental/rocm/rocm_device.c @@ -138,6 +138,7 @@ static iree_status_t iree_hal_rocm_device_create_internal( } if (iree_status_is_ok(status)) { status = iree_hal_rocm_allocator_create(&device->context_wrapper, + device->device, device->stream, &device->device_allocator); } if (iree_status_is_ok(status) && From f6593c81387088127e2876284625cd57f82e209e Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 18 Sep 2023 16:10:07 -0400 Subject: [PATCH 28/38] Set preferred location to the device for HIP Managed Memory The semantics for specifying different kinds of advice is unclear so I set it in two stages. --- experimental/rocm/dynamic_symbol_tables.h | 1 + experimental/rocm/rocm_allocator.c | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h index 4214c2b92775..b28acef5471b 100644 --- a/experimental/rocm/dynamic_symbol_tables.h +++ b/experimental/rocm/dynamic_symbol_tables.h @@ -25,6 +25,7 @@ RC_PFN_DECL(hipInit, unsigned int) RC_PFN_DECL(hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **) +RC_PFN_DECL(hipMemAdvise, const void *, size_t, int, int) RC_PFN_DECL(hipMemset, void *, int, size_t) RC_PFN_DECL(hipMemsetAsync, void *, int, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD32Async, void *, int, size_t, hipStream_t) diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c index 84dfb32f81c8..dbd0ea4b9936 100644 --- a/experimental/rocm/rocm_allocator.c +++ b/experimental/rocm/rocm_allocator.c @@ -285,6 +285,16 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer( status = ROCM_RESULT_TO_STATUS( allocator->context->syms, hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal)); + if (iree_status_is_ok(status)) { + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemAdvise(device_ptr, allocation_size, + hipMemAdviseSetPreferredLocation, allocator->device)); + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemAdvise(device_ptr, allocation_size, + hipMemAdviseSetCoarseGrain, allocator->device)); + } if (iree_status_is_ok(status) && allocator->supports_concurrent_managed_access) { // Prefetch the buffer on the GPU device. From 4bd693805089dc761b37eaa4e1d3bf78eda522a5 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Mon, 2 Oct 2023 10:30:50 -0400 Subject: [PATCH 29/38] [LLVMCPU] Add support for dynamic quantization + reassociation of grouped qmm MegaPR [LLVMCPU] Allow parallel tiling in LLVMCPUSplitReduction, tile reduction by 2 This commit enables tiling of parallel dimensions in LLVMCPUSplitReduction, as well as changing the tile size of the resulting reduction to 2. The latter change is an x86 specific optimization that allows targeting specific instructions through VectorContractCustomKernels. [LLVMCPU] Add support for vecmat cases in VectorContractCustomKernel This commit introduces some new functionality to VectorContractCustomKernels: 1. Matching for vecmat kernels that have 1D vector shapes 2. Support for `vector.contract` ops with split reduction dimensions 3. Ability to allow promoting smaller bitwidth inputs with `arith.extui` or `arith.extsi` before passing into the `llvm.inline_asm` op 4. Ability to specify explicit constraint strings per register input in a VectorContractCustomKernel 5. Support for `i4` and `i8` input types 6. New x86 AVX512VNNI i16xi16->i32 vecmat kernel with split reduction This commit also adds `vector.transfer_read` flattening patterns and VectorContractCustomKernel lowering patterns to LLVMCPUVectorLowering. [LLVMCPU] Add pass to breakdown subbyte `arith.extui` This pass breaks down `arith.extui` ops that have `i4` inputs into a sequence of `vector.shuffle->arith.andi->arith.shrui`. This avoids bad lowering of subbyte extends in x86 backend. This pass is somewhat specific to some work on vecmat VectorContractCustomKernels right now, and has some unique matchings. The pass also attempts to make use of AVX512 registers, so the vector size for the resulting IR is hardcoded as 512 bits. This needs to change before landing. This pass in general needs some refactoring before landing. [LLVMCPU] Add pass to fold away unit dimensions on `vector.contract` ops This pass folds away unit dimensions on `vector.contract` ops to get these ops into a form that is recognizable by the VectorContractCustomKernels patterns. This pass also hoists `vector.shape_cast` ops out of containing `scf.for` ops if possible when the shape cast operates on the accumulator of a `vector.contract` op. This pattern may be better off somewhere else, but for now it is here because the unit dim folding pattern can produce a hoistable `vector.shape_cast` op in cases with split reduction. [LLVMCPU] Add flag to restrict reassociated quantized matmul optimizations [LLVMCPU] Add additional Memref alias foldings [LLVMCPU] Simplify VectorContractCustomKernels x86 constraint codes, add new AVX512 kernel --- .../iree/compiler/Codegen/LLVMCPU/BUILD.bazel | 3 + .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 3 + .../LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp | 387 ++++++++++++++++++ .../LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp | 283 +++++++++++++ .../LLVMCPUFoldVectorContractUnitDims.cpp | 354 ++++++++++++++++ .../Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp | 56 ++- .../Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp | 39 ++ .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 33 +- .../iree/compiler/Codegen/LLVMCPU/Passes.h | 20 +- .../iree/compiler/Codegen/LLVMCPU/Passes.td | 25 ++ .../LLVMCPU/VectorContractCustomKernels.cpp | 383 ++++++++++++++--- .../compiler/GlobalOptimization/Passes.cpp | 4 +- 12 files changed, 1511 insertions(+), 79 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 8161510f8627..c045131f87f2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -53,8 +53,11 @@ iree_compiler_cc_library( "KernelDispatch.cpp", "LLVMCPUAssignConstantOrdinals.cpp", "LLVMCPUAssignImportOrdinals.cpp", + "LLVMCPUBreakDownSubbyteExtend.cpp", "LLVMCPUCheckIRBeforeLLVMConversion.cpp", "LLVMCPUEmitVectorizationRemarks.cpp", + "LLVMCPUFoldMemRefAliasOps.cpp", + "LLVMCPUFoldVectorContractUnitDims.cpp", "LLVMCPULinkExecutables.cpp", "LLVMCPULowerExecutableTarget.cpp", "LLVMCPUMmt4dVectorLowering.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 1250c4b17b06..c7278c37999e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -54,8 +54,11 @@ iree_cc_library( "KernelDispatch.cpp" "LLVMCPUAssignConstantOrdinals.cpp" "LLVMCPUAssignImportOrdinals.cpp" + "LLVMCPUBreakDownSubbyteExtend.cpp" "LLVMCPUCheckIRBeforeLLVMConversion.cpp" "LLVMCPUEmitVectorizationRemarks.cpp" + "LLVMCPUFoldMemRefAliasOps.cpp" + "LLVMCPUFoldVectorContractUnitDims.cpp" "LLVMCPULinkExecutables.cpp" "LLVMCPULowerExecutableTarget.cpp" "LLVMCPUMmt4dVectorLowering.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp new file mode 100644 index 000000000000..6d9b90384145 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp @@ -0,0 +1,387 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-breakdown-subbyte-extend" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { +namespace { + +template +static Value shuffleMaskShift(PatternRewriter &rewriter, Location loc, + SmallVector shuffleInputs, + int64_t srcBitWidth, int64_t vectorSize) { + auto shuffleInType = llvm::cast(shuffleInputs[0].getType()); + auto shuffleResultType = + VectorType::get({vectorSize}, shuffleInType.getElementType()); + int64_t dstBitWidth = shuffleInType.getElementTypeBitWidth(); + T maskBase = (1u << srcBitWidth) - 1; + + SmallVector maskArray(shuffleResultType.getNumElements()); + for (T elemNum = 0; elemNum < shuffleResultType.getNumElements(); elemNum++) { + maskArray[elemNum] = maskBase << (elemNum * srcBitWidth % dstBitWidth); + } + auto maskVals = rewriter.create( + loc, shuffleResultType, + DenseIntElementsAttr::get(shuffleResultType, maskArray)); + LDBG("maskVals: " << maskVals); + SmallVector shruiArray(shuffleResultType.getNumElements()); + for (T elemNum = 0; elemNum < shuffleResultType.getNumElements(); elemNum++) { + shruiArray[elemNum] = elemNum * srcBitWidth % dstBitWidth; + } + auto shruiVals = rewriter.create( + loc, shuffleResultType, + DenseIntElementsAttr::get(shuffleResultType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + + int64_t dstSize = vectorSize * shuffleInputs.size(); + auto newVectorType = + VectorType::get({dstSize}, shuffleResultType.getElementType()); + Value newVector = rewriter.create( + loc, newVectorType, rewriter.getZeroAttr(newVectorType)); + + for (auto shuffleIn : llvm::enumerate(shuffleInputs)) { + SmallVector shuffleArray(vectorSize); + for (int64_t elemNum = 0; elemNum < vectorSize; elemNum++) { + shuffleArray[elemNum] = + elemNum / (vectorSize / shuffleInType.getNumElements()); + } + Value shuffleResult = rewriter.create( + loc, shuffleIn.value(), shuffleIn.value(), shuffleArray); + LDBG("shuffleResult: " << shuffleResult); + + Value andResult = + rewriter.create(loc, shuffleResult, maskVals); + LDBG("andResult: " << andResult); + + Value shruiResult = + rewriter.create(loc, andResult, shruiVals); + LDBG("shruiResult: " << shruiResult); + + int64_t offset = shuffleIn.index() * vectorSize; + newVector = rewriter.create( + loc, shruiResult, newVector, offset, 1); + } + return newVector; +} + +static std::optional> +getLoadsForExtend(arith::ExtUIOp extOp) { + Value extSource = extOp.getIn(); + auto shapeCastOp = extSource.getDefiningOp(); + if (!shapeCastOp) { + return std::nullopt; + } + Value shapeCastSource = shapeCastOp.getSource(); + auto insertOp = shapeCastSource.getDefiningOp(); + if (!insertOp) { + return std::nullopt; + } + SmallVector loads; + while (insertOp) { + Value insert = insertOp.getSource(); + auto insertShapeCastOp = insert.getDefiningOp(); + if (!insertShapeCastOp) { + return std::nullopt; + } + auto loadOp = insertShapeCastOp.getSource().getDefiningOp(); + if (!loadOp) { + return std::nullopt; + } + loads.push_back(loadOp.getResult()); + insertOp = insertOp.getDest().getDefiningOp(); + } + return loads; +} + +struct BreakDownSubbyteExtend final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtUIOp extOp, + PatternRewriter &rewriter) const override { + VectorType extuiSrcType = + llvm::dyn_cast(extOp.getIn().getType()); + VectorType extuiDstType = llvm::dyn_cast(extOp.getType()); + if (!extuiSrcType || !extuiDstType) { + return failure(); + } + + SmallVector sources{extOp.getIn()}; + if (auto loads = getLoadsForExtend(extOp)) { + sources = *loads; + } + + int64_t srcElemBitwidth = extuiSrcType.getElementTypeBitWidth(); + int64_t dstElemBitwidth = extuiDstType.getElementTypeBitWidth(); + // We only have power-of-two bitwidth cases for now. + if (!llvm::isPowerOf2_64(dstElemBitwidth) || srcElemBitwidth != 4) + return failure(); + + if (dstElemBitwidth != 32 && dstElemBitwidth != 16) { + return failure(); + } + + int64_t vectorSizeBits = 512; + int64_t vectorSize = vectorSizeBits / dstElemBitwidth; + int64_t shuffleInputSizeBits = vectorSize * srcElemBitwidth; + int64_t shuffleInputSize = shuffleInputSizeBits / dstElemBitwidth; + auto shuffleInputType = + VectorType::get({shuffleInputSize}, extuiDstType.getElementType()); + Value shuffleInput = rewriter.create( + extOp.getLoc(), shuffleInputType, + rewriter.getZeroAttr(shuffleInputType)); + SmallVector shuffleInputs; + + for (int sourceIdx = 0; sourceIdx < sources.size(); sourceIdx++) { + Value source = sources[sourceIdx]; + VectorType sourceType = llvm::cast(source.getType()); + SmallVector sourceShape(sourceType.getShape()); + int64_t innerSize = sourceShape.back(); + if (!llvm::isPowerOf2_64(innerSize)) { + return failure(); + } + for (int64_t i = 0; i < sourceType.getNumElements() / innerSize; i++) { + SmallVector indices; + int64_t numElems = i; + SmallVector sourceOuterShape(sourceShape.begin(), + sourceShape.end() - 1); + for (int64_t size : llvm::reverse(sourceOuterShape)) { + indices.push_back(numElems % size); + numElems /= size; + } + std::reverse(indices.begin(), indices.end()); + + Value innerSlice; + if (indices.size()) { + innerSlice = rewriter.create(extOp.getLoc(), + source, indices); + } else { + innerSlice = source; + } + VectorType innerSliceType = + llvm::cast(innerSlice.getType()); + int64_t numExtractedBits = + innerSliceType.getNumElements() * srcElemBitwidth; + if (numExtractedBits / dstElemBitwidth < 1) { + LDBG("extract not big enough: " << numExtractedBits / + dstElemBitwidth); + return failure(); + } + auto bitCastType = VectorType::get({numExtractedBits / dstElemBitwidth}, + extuiDstType.getElementType()); + Value bitCastResult = rewriter.create( + extOp.getLoc(), bitCastType, innerSlice); + LDBG("innerSlice: " << innerSlice); + // LDBG("bitCastResult: " << bitCastResult); + + if (numExtractedBits >= shuffleInputSizeBits) { + for (int64_t extractOffset = 0; + extractOffset < numExtractedBits / dstElemBitwidth; + extractOffset += shuffleInputSize) { + Value extractedSlice = + rewriter.create( + extOp.getLoc(), bitCastResult, extractOffset, + shuffleInputSize, 1); + shuffleInputs.push_back(extractedSlice); + LDBG("extractedSlice: " << extractedSlice); + // vector = + // rewriter.create(extOp.getLoc(), + // extractedSlice, vector, SmallVector{offset}, + // SmallVector{1}); + } + } else { + int64_t offset = + i * numExtractedBits / dstElemBitwidth % shuffleInputSize; + shuffleInput = rewriter.create( + extOp.getLoc(), bitCastResult, shuffleInput, + SmallVector{offset}, SmallVector{1}); + if (offset + numExtractedBits / dstElemBitwidth == shuffleInputSize) { + shuffleInputs.push_back(shuffleInput); + shuffleInput = rewriter.create( + extOp.getLoc(), shuffleInputType, + rewriter.getZeroAttr(shuffleInputType)); + } + } + } + } + + Value newVector; + if (dstElemBitwidth == 32) { + newVector = shuffleMaskShift( + rewriter, extOp.getLoc(), shuffleInputs, srcElemBitwidth, vectorSize); + } else if (dstElemBitwidth == 16) { + newVector = shuffleMaskShift( + rewriter, extOp.getLoc(), shuffleInputs, srcElemBitwidth, vectorSize); + } + rewriter.replaceOpWithNewOp(extOp, extuiDstType, + newVector); + + return success(); + } +}; + +struct BreakDownSubbyteExtendFlatten final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtUIOp extOp, + PatternRewriter &rewriter) const override { + VectorType extuiSrcType = + llvm::dyn_cast(extOp.getIn().getType()); + VectorType extuiDstType = llvm::dyn_cast(extOp.getType()); + if (!extuiSrcType || !extuiDstType) { + return failure(); + } + LDBG("extuiSrcType: " << extuiSrcType); + LDBG("extuiDstType: " << extuiDstType); + + // We only have power-of-two bitwidth cases for now. + if (!llvm::isPowerOf2_64(extuiSrcType.getNumElements())) + return failure(); + + int64_t srcElemBitwidth = extuiSrcType.getElementTypeBitWidth(); + int64_t dstElemBitwidth = extuiDstType.getElementTypeBitWidth(); + LDBG("srcElemBitwidth: " << srcElemBitwidth); + LDBG("dstElemBitwidth: " << dstElemBitwidth); + + int64_t numBits = srcElemBitwidth * extuiSrcType.getNumElements(); + if (numBits / dstElemBitwidth < 1) { + return failure(); + } + + VectorType flattenedType = VectorType::get({extuiSrcType.getNumElements()}, + extuiSrcType.getElementType()); + Value shapeCastFlatten = rewriter.create( + extOp.getLoc(), flattenedType, extOp.getIn()); + + auto bitCastType = VectorType::get({numBits / dstElemBitwidth}, + extuiDstType.getElementType()); + Value bitCastResult = rewriter.create( + extOp.getLoc(), bitCastType, shapeCastFlatten); + LDBG("bitCastResult: " << bitCastResult); + + SmallVector shuffleArray(extuiDstType.getNumElements()); + for (int64_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shuffleArray[elemNum] = elemNum / (extuiDstType.getNumElements() / + bitCastType.getNumElements()); + } + + Value shuffleResult = rewriter.create( + extOp.getLoc(), bitCastResult, bitCastResult, shuffleArray); + LDBG("shuffleResult: " << shuffleResult); + + Value shapeCastUnflatten = rewriter.create( + extOp.getLoc(), extuiDstType, shuffleResult); + Value maskVals, shruiVals; + if (dstElemBitwidth == 32) { + int32_t maskBase = (1u << srcElemBitwidth) - 1; + SmallVector maskArray(extuiDstType.getNumElements()); + for (int32_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + maskArray[elemNum] = maskBase + << (elemNum * srcElemBitwidth % dstElemBitwidth); + } + maskVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, maskArray)); + LDBG("maskVals: " << maskVals); + + SmallVector shruiArray(extuiDstType.getNumElements()); + for (int32_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shruiArray[elemNum] = elemNum * srcElemBitwidth % dstElemBitwidth; + } + shruiVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + } else if (dstElemBitwidth == 16) { + int16_t maskBase = (1u << srcElemBitwidth) - 1; + SmallVector maskArray(extuiDstType.getNumElements()); + for (int16_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + maskArray[elemNum] = maskBase + << (elemNum * srcElemBitwidth % dstElemBitwidth); + } + maskVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, maskArray)); + LDBG("maskVals: " << maskVals); + + SmallVector shruiArray(extuiDstType.getNumElements()); + for (int16_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shruiArray[elemNum] = elemNum * srcElemBitwidth % dstElemBitwidth; + } + shruiVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + } else { + return failure(); + } + + Value andResult = rewriter.create( + extOp.getLoc(), shapeCastUnflatten, maskVals); + LDBG("andResult: " << andResult); + + rewriter.replaceOpWithNewOp(extOp, andResult, shruiVals); + + return success(); + } +}; + +struct LLVMCPUBreakDownSubbyteExtendPass final + : public LLVMCPUBreakDownSubbyteExtendBase< + LLVMCPUBreakDownSubbyteExtendPass> { + void runOnOperation() override { + MLIRContext *context = &getContext(); + { + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + + // For the case when the innermost dimension of the src type is too small to + // fill a single element of the dst type. + // { + // RewritePatternSet patterns(context); + // patterns.add(context); + // vector::populateVectorShapeCastLoweringPatterns(patterns); + // if (failed(applyPatternsAndFoldGreedily(getOperation(), + // std::move(patterns)))) { + // return signalPassFailure(); + // } + // } + } +}; + +} // namespace + +std::unique_ptr> +createLLVMCPUBreakDownSubbyteExtendPass() { + return std::make_unique(); +} + +void populateLLVMCPUBreakDownSubbyteExtendPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp new file mode 100644 index 000000000000..fc8c40dfb2e2 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp @@ -0,0 +1,283 @@ +//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass folds loading/storing from/to subview ops into +// loading/storing from/to the original memref. +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-fold-memref-alias-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +namespace mlir { +namespace iree_compiler { + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Merges expand_shape operation with load/transferRead operation. +template +class LLVMCPULoadOpOfExpandShapeOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges collapse_shape operation with load/transferRead operation. +template +class LLVMCPULoadOpOfCollapseShapeOpFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +static SmallVector +calculateExpandedAccessIndices(AffineMap affineMap, + const SmallVector &indices, Location loc, + PatternRewriter &rewriter) { + SmallVector indicesOfr(llvm::to_vector( + llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }))); + SmallVector expandedIndices; + for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, affineMap.getSubMap({i}), indicesOfr); + expandedIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return expandedIndices; +} + +static LogicalResult +resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, + memref::ExpandShapeOp expandShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + // The below implementation uses computeSuffixProduct method, which only + // allows int64_t values (i.e., static shape). Bail out if it has dynamic + // shapes. + if (!expandShapeOp.getResultType().hasStaticShape()) + return failure(); + + MLIRContext *ctx = rewriter.getContext(); + for (ArrayRef groups : expandShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + int64_t groupSize = groups.size(); + + // Construct the expression for the index value w.r.t to expand shape op + // source corresponding the indices wrt to expand shape op result. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + SmallVector dims(groupSize); + bindDimsList(ctx, MutableArrayRef{dims}); + AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct); + + /// Apply permutation and create AffineApplyOp. + SmallVector dynamicIndices(groupSize); + for (int64_t i = 0; i < groupSize; i++) + dynamicIndices[i] = indices[groups[i]]; + + // Creating maximally folded and composd affine.apply composes better with + // other transformations without interleaving canonicalization passes. + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/groupSize, + /*numSymbols=*/0, srcIndexExpr), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return success(); +} + +static LogicalResult +resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, + memref::CollapseShapeOp collapseShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + int64_t cnt = 0; + SmallVector tmp(indices.size()); + SmallVector dynamicIndices; + for (ArrayRef groups : collapseShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + dynamicIndices.push_back(indices[cnt++]); + int64_t groupSize = groups.size(); + + // Calculate suffix product for all collapse op source dimension sizes. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + + // Derive the index values along all dimensions of the source corresponding + // to the index wrt to collapsed shape op output. + auto d0 = rewriter.getAffineDimExpr(0); + SmallVector delinearizingExprs = delinearize(d0, suffixProduct); + + // Construct the AffineApplyOp for each delinearizingExpr. + for (int64_t i = 0; i < groupSize; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, + delinearizingExprs[i]), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + dynamicIndices.clear(); + } + if (collapseShapeOp.getReassociationIndices().empty()) { + auto zeroAffineMap = rewriter.getConstantAffineMap(0); + int64_t srcRank = + cast(collapseShapeOp.getViewSource().getType()).getRank(); + for (int64_t i = 0; i < srcRank; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, zeroAffineMap, dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + } + return success(); +} + +/// Helpers to access the memref operand for each op. +template +static Value getMemRefOperand(LoadOrStoreOpTy op) { + return op.getMemref(); +} + +static Value getMemRefOperand(vector::TransferReadOp op) { + return op.getSource(); +} + +static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } + +template +LogicalResult LLVMCPULoadOpOfExpandShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto expandShapeOp = + getMemRefOperand(loadOp).template getDefiningOp(); + + if (!expandShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineLoadOp = + dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesExpandShape( + loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), expandShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult LLVMCPULoadOpOfCollapseShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto collapseShapeOp = getMemRefOperand(loadOp) + .template getDefiningOp(); + + if (!collapseShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineLoadOp = + dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesCollapseShape( + loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), collapseShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +void populateLLVMCPUFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { + patterns.add, + LLVMCPULoadOpOfCollapseShapeOpFolder>( + patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { + +struct LLVMCPUFoldMemRefAliasOpsPass final + : public LLVMCPUFoldMemRefAliasOpsBase { + void runOnOperation() override; +}; + +} // namespace + +void LLVMCPUFoldMemRefAliasOpsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateFoldMemRefAliasOpPatterns(patterns); + populateLLVMCPUFoldMemRefAliasOpPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir \ No newline at end of file diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp new file mode 100644 index 000000000000..12388c2dd8f1 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp @@ -0,0 +1,354 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- LLVMCPUFoldVectorContractUnitDims.cpp - Pass to fold unit dims of +// vector.contract ops -===// +// +// Patterns to fold away unit dimensions on `vector.contract` ops +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-fold-unit-reduction-dims" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { + +// Given a `vector.contract` op and a set of indices to fold, this op rewrites +// the `vector.contract` op with surrounding `vector.shape_cast` ops to fold +// away the indicated indices. +static FailureOr +dropFoldableUnitIndices(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector foldIndices) { + SmallVector contractShape = *contractOp.getShapeForUnroll(); + SmallVector iteratorTypes = + contractOp.getIteratorTypesArray(); + auto indexingMaps = contractOp.getIndexingMapsArray(); + SmallVector> dstShapes; + SmallVector> dstExprs; + SmallVector inputs( + {contractOp.getLhs(), contractOp.getRhs(), contractOp.getAcc()}); + llvm::SetVector foldableDims; + for (int64_t dim : foldIndices) + foldableDims.insert(dim); + + for (AffineMap map : indexingMaps) { + SmallVector dstShape; + SmallVector dstExpr; + for (const auto &expr : enumerate(map.getResults())) { + if (auto dimExpr = expr.value().dyn_cast()) { + if (!foldableDims.contains(dimExpr.getPosition())) { + dstShape.push_back(contractShape[dimExpr.getPosition()]); + unsigned numSkipped = 0; + for (int64_t ind : foldIndices) { + if (dimExpr.getPosition() > ind) { + numSkipped++; + } + } + dstExpr.push_back( + rewriter.getAffineDimExpr(dimExpr.getPosition() - numSkipped)); + } + } else { + return failure(); + } + } + dstShapes.push_back(dstShape); + dstExprs.push_back(dstExpr); + } + + SmallVector newInputs; + SmallVector newIndexingMaps; + SmallVector newIteratorTypes; + for (auto iter : enumerate(iteratorTypes)) { + if (!foldableDims.contains(iter.index())) { + newIteratorTypes.push_back(iter.value()); + } + } + + for (int i = 0; i < 3; i++) { + // Shape unchanged + if (dstShapes[i].size() == indexingMaps[i].getResults().size()) { + newInputs.push_back(inputs[i]); + AffineMap newIndexingMap = + AffineMap::get(/*dimCount=*/contractShape.size() - foldIndices.size(), + /*symCount=*/0, dstExprs[i], contractOp.getContext()); + newIndexingMaps.push_back(newIndexingMap); + continue; + } + if (dstShapes[i].size() == 0) { + return failure(); + } + VectorType inputVecType = llvm::cast(inputs[i].getType()); + VectorType dstType = + VectorType::get(dstShapes[i], inputVecType.getElementType()); + + Value result; + auto extsiop = inputs[i].getDefiningOp(); + auto extuiop = inputs[i].getDefiningOp(); + if (!extsiop && !extuiop) { + result = rewriter.create(contractOp.getLoc(), + dstType, inputs[i]); + } else { + Value extIn = extsiop ? extsiop.getIn() : extuiop.getIn(); + VectorType extInType = llvm::dyn_cast(extIn.getType()); + VectorType shapeCastOutType = + VectorType::get(dstType.getShape(), extInType.getElementType()); + Value shapeCastResult = rewriter.create( + contractOp.getLoc(), shapeCastOutType, extIn); + result = extsiop ? rewriter + .create(contractOp.getLoc(), + dstType, shapeCastResult) + .getResult() + : rewriter + .create(contractOp.getLoc(), + dstType, shapeCastResult) + .getResult(); + } + AffineMap newIndexingMap = + AffineMap::get(/*dimCount=*/contractShape.size() - foldIndices.size(), + /*symCount=*/0, dstExprs[i], contractOp.getContext()); + newInputs.push_back(result); + newIndexingMaps.push_back(newIndexingMap); + } + auto newContract = + rewriter + .create( + contractOp.getLoc(), newInputs[0], newInputs[1], newInputs[2], + rewriter.getAffineMapArrayAttr(newIndexingMaps), + rewriter.getArrayAttr(llvm::to_vector(llvm::map_range( + newIteratorTypes, + [&](vector::IteratorType t) -> mlir::Attribute { + return vector::IteratorTypeAttr::get(rewriter.getContext(), + t); + })))) + .getResult(); + return newContract; +} + +// This pattern matches on a `vector.contract` op with unit size dimensions, and +// folds these dimensions away +class DropVectorContractUnitDims final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + LDBG("vector.contract op:\n" << contractOp); + VectorType outputType = + llvm::dyn_cast(contractOp.getAcc().getType()); + if (!outputType) { + return failure(); + } + + auto parentOp = contractOp->getParentOfType(); + if (parentOp) { + return failure(); + } + + auto iteratorTypes = contractOp.getIteratorTypesArray(); + SmallVector contractDims = *contractOp.getShapeForUnroll(); + unsigned numParallel = 0; + unsigned numReduction = 0; + SmallVector unitParallelDims; + SmallVector unitReductionDims; + SmallVector foldableDims; + for (auto size : enumerate(contractDims)) { + if (iteratorTypes[size.index()] == vector::IteratorType::parallel) { + numParallel++; + if (size.value() == 1) { + unitParallelDims.push_back(size.index()); + } + } else { + numReduction++; + if (size.value() == 1) { + unitReductionDims.push_back(size.index()); + } + } + } + if (numReduction && numReduction == unitReductionDims.size()) { + foldableDims.append(unitReductionDims.begin(), + unitReductionDims.end() - 1); + } else { + foldableDims.append(unitReductionDims.begin(), unitReductionDims.end()); + } + if (numParallel && numParallel == unitParallelDims.size()) { + foldableDims.append(unitParallelDims.begin() + 1, unitParallelDims.end()); + } else { + foldableDims.append(unitParallelDims.begin(), unitParallelDims.end()); + } + if (!foldableDims.size()) { + return failure(); + } + + FailureOr maybeNewContract = + dropFoldableUnitIndices(rewriter, contractOp, foldableDims); + if (failed(maybeNewContract)) { + return failure(); + } + Value newContract = maybeNewContract.value(); + LDBG("Replaced vector.contract:\n" << newContract); + + VectorType newOutputType = + llvm::dyn_cast(newContract.getType()); + if (outputType != newOutputType) { + // Reshape output of new vector.contract if needed + Value shapeCastResult = rewriter.create( + contractOp.getLoc(), outputType, newContract); + rewriter.replaceOp(contractOp, shapeCastResult); + } else { + rewriter.replaceOp(contractOp, newContract); + } + + return success(); + } +}; + +// This pattern matches on a sequence of +// `vector.shape_cast->vector.contract->vector.shape_cast` within an `scf.for` +// op, where the shape cast ops are casting an argument of the `scf.for` op and +// the yielded result of the `scf.for` op. Once matched, the `vector.shape_cast` +// ops are hoisted out of the `scf.for` op. +class HoistShapeCastOutOfSCFFor final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + LDBG("forOp:\n" << forOp); + auto yieldOp = cast(forOp.getBody()->getTerminator()); + std::optional> + hoistableShapeCast = std::nullopt; + int initArgIdx; + for (Value result : yieldOp.getOperation()->getOperands()) { + auto outputShapeCastOp = result.getDefiningOp(); + if (!outputShapeCastOp) { + continue; + } + LDBG("outputShapeCastOp:\n" << outputShapeCastOp); + auto contractOp = + outputShapeCastOp.getSource().getDefiningOp(); + if (!contractOp) { + continue; + } + LDBG("contractOp:\n" << contractOp); + Value acc = contractOp.getAcc(); + auto inputShapeCastOp = acc.getDefiningOp(); + if (!inputShapeCastOp) { + continue; + } + LDBG("inputShapeCastOp:\n" << inputShapeCastOp); + Value input = inputShapeCastOp.getSource(); + auto blockArg = dyn_cast(input); + if (!blockArg) { + continue; + } + LDBG("blockArg:\n" << blockArg); + hoistableShapeCast = std::make_pair(inputShapeCastOp, outputShapeCastOp); + initArgIdx = blockArg.getArgNumber() - 1; + } + + if (!hoistableShapeCast) { + return failure(); + } + vector::ShapeCastOp inSC = hoistableShapeCast->first; + vector::ShapeCastOp outSC = hoistableShapeCast->second; + SmallVector forOpInitArgs = forOp.getInitArgs(); + Value source = forOpInitArgs[initArgIdx]; + Value sourceSC = + rewriter + .create(forOp.getLoc(), inSC.getType(), source) + .getResult(); + forOpInitArgs[initArgIdx] = sourceSC; + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), forOpInitArgs); + LDBG("newForOp:\n" << newForOp); + rewriter.mergeBlocks(forOp.getBody(), newForOp.getBody(), + newForOp.getBody()->getArguments()); + auto newYieldOp = cast(newForOp.getBody()->getTerminator()); + LDBG("newYieldOp:\n" << newYieldOp); + SmallVector newForOpResults = + newYieldOp.getOperation()->getOperands(); + int contractResultIndex; + for (auto result : llvm::enumerate(newForOpResults)) { + if (result.value() == outSC.getResult()) { + newForOpResults[result.index()] = outSC.getSource(); + contractResultIndex = result.index(); + } + } + rewriter.updateRootInPlace(newYieldOp, [&]() { + newYieldOp.getOperation()->setOperands(newForOpResults); + }); + LDBG("newForOp with body:\n" << newForOp); + SmallVector newResults = newForOp.getResults(); + Value hoistedOutputShapeCast = + rewriter + .create(forOp.getLoc(), outSC.getType(), + newResults[contractResultIndex]) + .getResult(); + LDBG("hoistedOutputShapeCast:\n" << hoistedOutputShapeCast); + newResults[contractResultIndex] = hoistedOutputShapeCast; + rewriter.replaceOp(forOp, newResults); + + return success(); + } +}; + +namespace { +struct LLVMCPUFoldVectorContractUnitDimsPass + : public LLVMCPUFoldVectorContractUnitDimsBase< + LLVMCPUFoldVectorContractUnitDimsPass> { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; +}; +} // namespace + +void LLVMCPUFoldVectorContractUnitDimsPass::runOnOperation() { + Operation *funcOp = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet foldUnitDimsPatterns(context); + foldUnitDimsPatterns + .add(context); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(foldUnitDimsPatterns)))) { + return signalPassFailure(); + } +} + +std::unique_ptr> +createLLVMCPUFoldVectorContractUnitDimsPass() { + return std::make_unique(); +} + +void populateFoldVectorContractUnitDimsPass(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp index b0b36909f8fc..f00907c52a5e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp @@ -38,8 +38,9 @@ namespace { /// TODO: support named ops, numInputs > 1, and modify lastDim check below /// accordingly. If fpReductionReordering is not enabled by default, it must /// be an integer or index type to proceed to allow associative reordering. -LogicalResult splitReductionPrecondition(Operation *op, - bool fpReductionReordering) { +LogicalResult +splitReductionPrecondition(Operation *op, bool fpReductionReordering, + bool enableQuantizedMatmulReassociation) { linalg::LinalgOp linalgOp = cast(op); if (!linalgOp.hasTensorSemantics()) { @@ -63,7 +64,11 @@ LogicalResult splitReductionPrecondition(Operation *op, LLVM_DEBUG(llvm::dbgs() << "is not a generic op\n"); return failure(); } - if (linalgOp.getNumDpsInputs() != 1) { + if (enableQuantizedMatmulReassociation && linalgOp.getNumDpsInputs() > 2) { + LLVM_DEBUG(llvm::dbgs() << "doesn't have at most 2 inputs\n"); + return failure(); + } + if (!enableQuantizedMatmulReassociation && linalgOp.getNumDpsInputs() != 1) { LLVM_DEBUG(llvm::dbgs() << "doesn't have exactly 1 input\n"); return failure(); } @@ -102,8 +107,10 @@ LogicalResult splitReductionPrecondition(Operation *op, /// Converts an inner-reduction into outer reduction + inner-parallel dimension, /// followed by simple inner reduction. -LogicalResult splitReductionImpl(Operation *op, int64_t size, +LogicalResult splitReductionImpl(Operation *op, SmallVector tileSizes, + bool enableQuantizedMatmulReassociation, RewriterBase &rewriter) { + int64_t size = tileSizes.back(); IRRewriter::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(op); linalg::LinalgOp linalgOp = cast(op); @@ -119,8 +126,19 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, auto numLoops = linalgOp.getNumLoops(); // 1) Tile to extract a single vector-length array. - SmallVector tileSizesSVFirst(numLoops, - rewriter.getIndexAttr(1)); + SmallVector tileSizesSVFirst; + if (enableQuantizedMatmulReassociation) { + for (auto &s : tileSizes) { + if (!s) { + tileSizesSVFirst.push_back(rewriter.getIndexAttr(1)); + } else { + tileSizesSVFirst.push_back(rewriter.getIndexAttr(s)); + } + } + } else { + tileSizesSVFirst = + SmallVector(numLoops, rewriter.getIndexAttr(1)); + } tileSizesSVFirst[numLoops - 1] = rewriter.getIndexAttr(0); auto options = scf::SCFTilingOptions().setTileSizes(tileSizesSVFirst); FailureOr tileResFirst = scf::tileUsingSCFForOp( @@ -147,7 +165,11 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, rewriter.getIndexAttr(0)); // The reduction happens only in the penultimate dimension, which we now // tile. - tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(1); + if (enableQuantizedMatmulReassociation) { + tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(2); + } else { + tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(1); + } options = scf::SCFTilingOptions().setTileSizes(tileSizesSV); FailureOr tileRes = scf::tileUsingSCFForOp( rewriter, cast(splitRes->splitLinalgOp.getOperation()), @@ -164,8 +186,11 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, class LLVMCPUSplitReductionPass : public LLVMCPUSplitReductionBase { public: - LLVMCPUSplitReductionPass(bool fpReductionReordering) { + LLVMCPUSplitReductionPass(bool fpReductionReordering, + bool enableQuantizedMatmulReassociation) { this->enableFpReductionReordering = fpReductionReordering; + this->enableQuantizedMatmulReassociation = + enableQuantizedMatmulReassociation; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -183,8 +208,9 @@ void LLVMCPUSplitReductionPass::runOnOperation() { funcOp.walk([&](linalg::GenericOp op) { candidates.push_back(op); }); for (auto genericOp : candidates) { LLVM_DEBUG(llvm::dbgs() << "candidate: " << genericOp << "\n"); - if (failed(splitReductionPrecondition(genericOp, - enableFpReductionReordering))) { + if (failed( + splitReductionPrecondition(genericOp, enableFpReductionReordering, + enableQuantizedMatmulReassociation))) { continue; } @@ -208,8 +234,9 @@ void LLVMCPUSplitReductionPass::runOnOperation() { "skip SplitReduction"); continue; } - int64_t size = reductionSizes.back(); - if (failed(splitReductionImpl(genericOp, size, rewriter))) { + if (failed(splitReductionImpl(genericOp, reductionSizes, + enableQuantizedMatmulReassociation, + rewriter))) { return signalPassFailure(); } } @@ -218,9 +245,10 @@ void LLVMCPUSplitReductionPass::runOnOperation() { } // namespace std::unique_ptr> -createLLVMCPUSplitReductionPass(const bool enableFpReductionReordering) { +createLLVMCPUSplitReductionPass(const bool enableFpReductionReordering, + const bool enableQuantizedMatmulReassociation) { return std::make_unique( - enableFpReductionReordering); + enableFpReductionReordering, enableQuantizedMatmulReassociation); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp index 65be5e78bd03..2668c15ec999 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp @@ -45,6 +45,8 @@ class LLVMCPUVectorLoweringPass LLVMCPUVectorLoweringPass(const LLVMCPUVectorLoweringPassOptions &options) { this->splitVectorTransfersTo = options.splitVectorTransfersTo; this->lowerVectorTransposeToAVX2 = options.lowerVectorTransposeToAVX2; + this->enableQuantizedMatmulReassociation = + options.enableQuantizedMatmulReassociation; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -77,6 +79,26 @@ void LLVMCPUVectorLoweringPass::runOnOperation() { .setVectorMultiReductionLowering(vectorMultiReductionLowering) .setVectorTransferSplit(vectorTransferSplit); + { + if (enableQuantizedMatmulReassociation) { + // Special-case vector.contract codegen paths. This needs to happen + // just before the generic vector ops lowerings. + RewritePatternSet patterns(ctx); + auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); + populateVectorContractCustomKernelsPatterns(target, patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After custom kernel lowering for " + "vector.contract ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + { RewritePatternSet patterns(ctx); vector::populateVectorGatherLoweringPatterns(patterns); @@ -173,6 +195,23 @@ void LLVMCPUVectorLoweringPass::runOnOperation() { llvm::dbgs() << "\n\n"; }); + // Break down subbyte `arith.extui` ops + { + if (enableQuantizedMatmulReassociation) { + RewritePatternSet patterns(&getContext()); + populateLLVMCPUBreakDownSubbyteExtendPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After breaking down subbyte extend ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + // 'vector.shape_cast' are very expensive operations that are even generated // by some of the lowerings above (e.g., transpose lowering). There are // chances to cancel them out if they are not lowered too early so we lower diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index b0fbf23e2cca..aea660c51bf2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -79,6 +79,13 @@ static llvm::cl::opt clInstrumentMemoryAccesses{ "instrumentation is enabled."), llvm::cl::init(false)}; +static llvm::cl::opt clEnableQuantizedMatmulReassociation( + "iree-llvmcpu-enable-quantized-matmul-reassociation", + llvm::cl::desc( + "Enables LLVMCPU codegen optimizations specific to reassociated " + "quantized matmuls (experimental)."), + llvm::cl::init(false)); + static void addTileAndDistributePasses(OpPassManager &pm) { pm.addPass(createTileAndDistributeToWorkgroupsPass()); auto &nestedModulePM = pm.nest(); @@ -174,6 +181,19 @@ LogicalResult verifyDoubleTilingExpertPassPipelineConfig( << index << "-th tile size set"; } } + // if (!clEnableQuantizedMatmulReassociation) { + // SmallVector thirdLevelTileSizes; + // std::tie(thirdLevelTileSizes, std::ignore) = + // tilingConfig.getVectorReductionSizes(); + // for (auto [index, tileSize] : llvm::enumerate(thirdLevelTileSizes)) { + // if (tileSize != 0 && pLoopsSet.contains(index)) { + // return op->emitOpError("expected only reduction dims to be set in " + // "the third tiling " + // "level, got ") + // << index << "-th tile size set"; + // } + // } + // } } // Verify interchange @@ -348,7 +368,9 @@ void addMultiTilingExpertPassPipeline( // Run SplitReductionPass before the final reduction Fuse pass, because // SplitReductionPass takes care of banked-tiling. nestedModulePM.addNestedPass( - createLLVMCPUSplitReductionPass(clEnableReassociateFpReductions)); + createLLVMCPUSplitReductionPass( + clEnableReassociateFpReductions, + clEnableQuantizedMatmulReassociation)); nestedModulePM.addNestedPass(createLLVMCPUTilePass(i)); continue; } @@ -385,11 +407,17 @@ void addMultiTilingExpertPassPipeline( // Run IREE specific passes before vector lowering expert. nestedModulePM.addNestedPass( createRemoveSingleIterationLoopPass()); + if (clEnableQuantizedMatmulReassociation) { + nestedModulePM.addNestedPass( + createLLVMCPUFoldVectorContractUnitDimsPass()); + } { LLVMCPUVectorLoweringPassOptions options; options.lowerVectorTransposeToAVX2 = lowerToAVX2; options.splitVectorTransfersTo = "linalg-copy"; + options.enableQuantizedMatmulReassociation = + clEnableQuantizedMatmulReassociation; nestedModulePM.addNestedPass( createLLVMCPUVectorLoweringPass(options)); } @@ -649,6 +677,9 @@ static void addLowerToLLVMPasses(OpPassManager &passManager, passManager.addNestedPass(arith::createArithExpandOpsPass()); passManager.addNestedPass(memref::createExpandOpsPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); + if (clEnableQuantizedMatmulReassociation) { + passManager.addPass(createLLVMCPUFoldMemRefAliasOpsPass()); + } passManager.addPass(createEmulateNarrowTypePass()); passManager.addPass(createCanonicalizerPass()); passManager.addPass(createCSEPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 75b82de33a1d..cc00401b9bbb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -19,6 +19,10 @@ namespace mlir::iree_compiler { class TilingConfig; +// Pass to breakdown subbyte extui +std::unique_ptr> +createLLVMCPUBreakDownSubbyteExtendPass(); + /// Performs the final conversion to LLVM dialect. std::unique_ptr> createConvertToLLVMPass(bool reassociateFpReordering = false); @@ -55,8 +59,9 @@ createLLVMCPUMmt4dVectorLoweringPass(); std::unique_ptr> createLLVMCPUPeelPass(); /// Pass to perform SplitReduction transformations of `LinalgOp`s. -std::unique_ptr> -createLLVMCPUSplitReductionPass(bool enableReassociateFpReductions = false); +std::unique_ptr> createLLVMCPUSplitReductionPass( + bool enableReassociateFpReductions = false, + bool enableQuantizedMatmulReassociation = false); /// Synchronizes LLVM linkage with MLIR symbol visibility. std::unique_ptr> @@ -82,6 +87,7 @@ std::unique_ptr> createLLVMCPUUnfuseFMAOpsPass(); struct LLVMCPUVectorLoweringPassOptions { std::string splitVectorTransfersTo = ""; bool lowerVectorTransposeToAVX2 = false; + bool enableQuantizedMatmulReassociation = false; }; std::unique_ptr> createLLVMCPUVectorLoweringPass(); std::unique_ptr> createLLVMCPUVectorLoweringPass( @@ -96,6 +102,11 @@ createVectorContractCustomKernelsPass(); std::unique_ptr> createVerifyLinalgTransformLegalityPass(); +std::unique_ptr> +createLLVMCPUFoldVectorContractUnitDimsPass(); + +std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass(); + //------------------------------------------------------------------------------ // LLVMCPU Codegen specific patterns. //------------------------------------------------------------------------------ @@ -108,6 +119,11 @@ void populateUnfusedFMAOpsPassPatterns(MLIRContext *context, void populateVectorContractCustomKernelsPatterns( IREE::HAL::ExecutableTargetAttr target, RewritePatternSet &patterns); +void populateLLVMCPUBreakDownSubbyteExtendPatterns(RewritePatternSet &patterns); + +void populateFoldVectorContractUnitDimsPass(RewritePatternSet &patterns, + MLIRContext *context); + //----------------------------------------------------------------------------// // LLVMCPU backend Pass Pipelines. //----------------------------------------------------------------------------// diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index 1491a3ba03be..a03f9cd575be 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -45,6 +45,11 @@ def LLVMCPUAssignImportOrdinals : let constructor = "mlir::iree_compiler::createLLVMCPUAssignImportOrdinalsPass()"; } +def LLVMCPUBreakDownSubbyteExtend : Pass<"iree-llvmcpu-breakdown-subbyte-extend", "func::FuncOp"> { + let summary = "Pass to break down subbyte extui ops."; + let constructor = "mlir::iree_compiler::createLLVMCPUBreakDownSubbyteExtendPass()"; +} + def LLVMCPUCheckIRBeforeLLVMConversion : Pass<"iree-llvmcpu-check-ir-before-llvm-conversion", "ModuleOp"> { let summary = "Checks CPU backend specific IR constraints (like no allocas)"; @@ -58,6 +63,20 @@ def LLVMCPUEmitVectorizationRemarks : "mlir::iree_compiler::createLLVMCPUEmitVectorizationRemarksPass()"; } +def LLVMCPUFoldVectorContractUnitDims : + Pass<"iree-llvmcpu-fold-vector-contract-unit-dims", "func::FuncOp"> { + let summary = "Fold unit dims on vector.contract ops"; + let constructor = + "mlir::iree_compiler::createLLVMCPUFoldVectorContractUnitDimsPass()"; +} + +def LLVMCPUFoldMemRefAliasOps : + Pass<"iree-llvmcpu-fold-memref-alias-ops", ""> { + let summary = "Fold combinations of memref ops"; + let constructor = + "mlir::iree_compiler::createLLVMCPUFoldMemRefAliasOpsPass()"; +} + def LLVMCPULinkExecutables : Pass<"iree-llvmcpu-link-executables", "mlir::ModuleOp"> { let summary = "Links LLVMCPU HAL executables within the top-level program module."; @@ -103,6 +122,9 @@ def LLVMCPUSplitReduction : Pass<"iree-llvmcpu-split-reduction", "func::FuncOp"> Option<"enableFpReductionReordering", "enable-fp-reduction-reordering", "bool", /*default=*/"false", "Flag to enable reduction reordering on floating points.">, + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", + "bool", /*default=*/"false", + "Flag to enable optimizations for reassociated quantized matmuls.">, ]; } @@ -162,6 +184,9 @@ def LLVMCPUVectorLowering : Option<"lowerVectorTransposeToAVX2", "lower-vector-transpose-to-avx2", "bool", /*default=*/"false", "Add specific transpose to avx2 lowering patterns.">, + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", "bool", + /*default=*/"false", + "Add specific patterns for optimizing reassociated quantized matmuls.">, ]; let constructor = "mlir::iree_compiler::createLLVMCPUVectorLoweringPass()"; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp index 9c22ca1078fd..202e17521a84 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp @@ -25,6 +25,10 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define DEBUG_TYPE "iree-vector-contract-custom-kernels" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir::iree_compiler { namespace { @@ -85,6 +89,73 @@ static bool isMatrixTimesMatrixTransposed(vector::ContractionOp contractionOp) { return true; } +static bool isVectorTimesMatrixTransposed(vector::ContractionOp contractionOp, + int64_t splitSize) { + // Check that the reduction is additive. + if (contractionOp.getKind() != vector::CombiningKind::ADD) { + return false; + } + // Check that there are 1 parallel and 1 reduction iterators. + unsigned numIters = splitSize ? 3 : 2; + auto iteratorTypes = contractionOp.getIteratorTypes().getValue(); + if (iteratorTypes.size() != numIters) { + return false; + } + SmallVector parallelIterators; + SmallVector reductionIterators; + for (int i = 0; i < numIters; i++) { + if (vector::isParallelIterator(iteratorTypes[i])) { + parallelIterators.push_back(i); + } else if (vector::isReductionIterator(iteratorTypes[i])) { + reductionIterators.push_back(i); + } else { + return false; + } + } + if (parallelIterators.size() != numIters - 1 || + reductionIterators.size() != 1) { + return false; + } + // Give the found iterators some idiomatic names. + const int NIter = parallelIterators[0]; + const int KIter = reductionIterators[0]; + const int SplitIter = splitSize ? parallelIterators[1] : 0; + // Check that there are 3 indexing maps. + auto indexingMaps = contractionOp.getIndexingMapsArray(); + if (indexingMaps.size() != 3) { + return false; + } + // Check that the indexing maps have the expected form. + SmallVector> expectedMapResults; + if (splitSize) { + SmallVector> res = { + {KIter, SplitIter}, {NIter, KIter, SplitIter}, {NIter, SplitIter}}; + expectedMapResults = res; + numIters = 3; + } else { + SmallVector> res = {{KIter}, {NIter, KIter}, {NIter}}; + expectedMapResults = res; + numIters = 2; + } + for (int m = 0; m < 3; ++m) { + auto map = indexingMaps[m]; + auto expectedResults = expectedMapResults[m]; + if (map.getNumDims() != numIters || + map.getNumResults() != expectedResults.size()) { + return false; + } + for (int r = 0; r < expectedResults.size(); ++r) { + int actualMapResult = + map.getResults()[r].cast().getPosition(); + if (actualMapResult != expectedMapResults[m][r]) { + return false; + } + } + } + LDBG("passed isVectorTimesMatrixTransposed"); + return true; +} + // Returns true if `contractionOp` is of the form // matrix * transposed_matrix // where matrix is a vector<{mSize}x{kSize}xType>, and @@ -131,6 +202,31 @@ static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, return false; } +static bool matchVMT(vector::ContractionOp contractionOp, int64_t mSize, + int64_t kSize, int64_t nSize, int splitSize, + bool *transpose = nullptr) { + if (mSize != 1) { + return false; + } + if (!isVectorTimesMatrixTransposed(contractionOp, splitSize)) { + return false; + } + VectorType lhsType = llvm::cast(contractionOp.getLhs().getType()); + VectorType rhsType = llvm::cast(contractionOp.getRhs().getType()); + auto lhsShape = lhsType.getShape(); + auto rhsShape = rhsType.getShape(); + if (splitSize && (lhsShape[1] != splitSize || rhsShape[2] != splitSize)) { + return false; + } + if (lhsShape[0] != kSize || rhsShape[1] != kSize) { + return false; + } + if (rhsShape[0] == nSize) { + return true; + } + return false; +} + // `promotedResult` is required to be a Vector. // If its VectorType does not have `promotedType` as its element type, or // the operand to the type-promotion op is not `unpromotedType` returns a null @@ -142,8 +238,9 @@ static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, // Note that this only looks at the immediately defining operation, so we likely // want to have earlier passes that sink widening operations as far down as // possible, which is probably just good regardless. -static Value getUnpromotedInput(Type unpromotedType, Type promotedType, - Value promotedResult) { +static Value getUnpromotedInput(PatternRewriter &rewriter, Type unpromotedType, + Type promotedType, Value promotedResult, + bool promoteSmallTypes = false) { VectorType promotedResultVectorType = llvm::cast(promotedResult.getType()); if (promotedResultVectorType.getElementType() != promotedType) { @@ -155,13 +252,29 @@ static Value getUnpromotedInput(Type unpromotedType, Type promotedType, // TODO: handle promotion of floating point types. Not doing it for now as // it wouldn't be exercised. auto extSIOp = promotedResult.getDefiningOp(); - if (!extSIOp) { + auto extUIOp = promotedResult.getDefiningOp(); + if (!extSIOp && !extUIOp) { return nullptr; } - Value extInput = extSIOp.getIn(); + Value extInput = extSIOp ? extSIOp.getIn() : extUIOp.getIn(); if (llvm::cast(extInput.getType()).getElementType() != unpromotedType) { - return nullptr; + if (promoteSmallTypes) { + VectorType unpromotedVectorType = + VectorType::get(llvm::cast(extInput.getType()).getShape(), + unpromotedType); + return extSIOp + ? rewriter + .create(extInput.getLoc(), + unpromotedVectorType, extInput) + .getResult() + : rewriter + .create(extInput.getLoc(), + unpromotedVectorType, extInput) + .getResult(); + } else { + return nullptr; + } } return extInput; } @@ -169,12 +282,28 @@ static Value getUnpromotedInput(Type unpromotedType, Type promotedType, // Helper to create a 1D, contiguous slice of a 1D vector. static Value extract1DSlice(PatternRewriter &rewriter, Location loc, VectorType dstVecType, Value input, int position) { - assert(input.getType().cast().getRank() == 1); assert(dstVecType.getRank() == 1); - std::array offsets{position}; - std::array strides{1}; - return rewriter.create( - loc, input, offsets, dstVecType.getShape(), strides); + if (input.getType().cast().getRank() == 1) { + SmallVector offsets({position}); + SmallVector strides({1}); + SmallVector sizes(dstVecType.getShape()); + return rewriter.create(loc, input, offsets, + sizes, strides); + } else { + SmallVector inputShape( + llvm::cast(input.getType()).getShape()); + assert(inputShape.back() == dstVecType.getNumElements()); + std::reverse(inputShape.begin(), inputShape.end()); + int currentPos = position; + SmallVector indices; + for (auto size : inputShape) { + indices.push_back(currentPos % size); + currentPos = currentPos / size; + } + std::reverse(indices.begin(), indices.end()); + return rewriter.create( + loc, input, SmallVector(indices.begin(), indices.end() - 1)); + } } // Helper to extract an element of a 1D vector. @@ -188,8 +317,12 @@ static Value extract(PatternRewriter &rewriter, Location loc, Value input, } // Helper to flatten a N-dimensional vector to a 1D vector. -static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { +static Value flattenImperfectSize(PatternRewriter &rewriter, Location loc, + Value vector, VectorType regVectorType) { VectorType inputVecType = llvm::cast(vector.getType()); + if (regVectorType.getNumElements() == inputVecType.getShape().back()) { + return vector; + } VectorType dstType = VectorType::get(inputVecType.getNumElements(), inputVecType.getElementType()); return rewriter.create(loc, dstType, vector); @@ -206,20 +339,31 @@ static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { // (2) Be explicit about the size of the vectors involved in the kernel's // "calling convention". struct MMTKernel { - enum class ScalarType : int8_t { None, I8, I32, F32 }; + enum class ScalarType : int8_t { None, I4, I8, I16, I32, F32 }; // Element type of the LHS vectors. ScalarType lhsType = ScalarType::None; // Element type of the RHS vectors. ScalarType rhsType = ScalarType::None; // Element type of the Accumulator and output vectors. ScalarType accType = ScalarType::None; + // Optional user defined constrained codes for input and output registers. + // This is useful when the constraint code is not the same for all operands. + std::optional> lhsCode = std::nullopt; + std::optional> rhsCode = std::nullopt; + std::optional> accCode = std::nullopt; + // This flag indicates whether or not to promote inputs that have a smaller + // bitwidth than lhsType, rhsType, or accType, to the appropriate bitwidth + bool promoteSmallTypes = false; // Number of rows of the LHS and Accumulator tile. - int8_t m0 = 0; + int16_t m0 = 0; // Reduction dimension, i.e. number of columns of the LHS. - int8_t k0 = 0; + int16_t k0 = 0; // Number of rows of the RHS (note that the operation being targeted, MMT, // is matrix multiplication with a *transposed* RHS) - int8_t n0 = 0; + int16_t n0 = 0; + // Size of the added parallel dimension when the vector.contract op has been + // split with splitReduction + int16_t split0 = 0; // Number of LHS elements in the type of register to be used for the LHS. // This is > 1 if SIMD registers are to be used. // Note: LHS/RHS/Accumulator may use registers of different sizes. @@ -235,6 +379,8 @@ struct MMTKernel { int8_t rhsRegs = 0; // Number of registers needed to hold the Accumulator. int8_t accRegs = 0; + // Indicates whether to use Intel or AT&T syntax + bool useIntel = false; // If not null, points to the inline asm code template for this kernel. // Register operands for the LHS, RHS and Accumulator are to be referenced as // $(lhs:), $(rhs:), $(acc:) respectively, where i is a decimal @@ -249,9 +395,15 @@ struct MMTKernel { const char *asmClobbers = nullptr; void validate() const { - assert(m0 * k0 == lhsRegSize * lhsRegs); // number of elements of LHS - assert(n0 * k0 == rhsRegSize * rhsRegs); // number of elements of RHS - assert(m0 * n0 == accRegSize * accRegs); // number of elements of Accum + assert(m0 * k0 == lhsRegSize * lhsRegs || + m0 * k0 * split0 == + lhsRegSize * lhsRegs); // number of elements of LHS + assert(n0 * k0 == rhsRegSize * rhsRegs || + n0 * k0 * split0 == + rhsRegSize * rhsRegs); // number of elements of RHS + assert(m0 * n0 == accRegSize * accRegs || + m0 * n0 * split0 == + accRegSize * accRegs); // number of elements of Accum assert(lhsType != ScalarType::None); assert(rhsType != ScalarType::None); assert(accType != ScalarType::None); @@ -673,13 +825,75 @@ MMTKernel MMTKernel_8x1x1_f32f32f32_Aarch64_Baseline_InlineAsm() { return kernel; } +MMTKernel MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm() { + MMTKernel kernel; + kernel.lhsType = MMTKernel::ScalarType::I16; + kernel.rhsType = MMTKernel::ScalarType::I16; + kernel.accType = MMTKernel::ScalarType::I32; + kernel.promoteSmallTypes = true; + kernel.useIntel = true; + kernel.m0 = 1; + kernel.k0 = 2; + kernel.n0 = 4; + kernel.split0 = 16; + kernel.lhsRegSize = 32; + kernel.rhsRegSize = 32; + kernel.accRegSize = 16; + kernel.lhsRegs = 1; + kernel.rhsRegs = 4; + kernel.accRegs = 4; + kernel.asmImpl = R"ASM( + vpdpwssd $(acc:0), $(rhs:0), $(lhs:0) + vpdpwssd $(acc:1), $(rhs:1), $(lhs:0) + vpdpwssd $(acc:2), $(rhs:2), $(lhs:0) + vpdpwssd $(acc:3), $(rhs:3), $(lhs:0) + )ASM"; + kernel.asmClobbers = ""; + return kernel; +} + +MMTKernel MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512_InlineAsm() { + MMTKernel kernel; + kernel.lhsType = MMTKernel::ScalarType::I16; + kernel.rhsType = MMTKernel::ScalarType::I16; + kernel.accType = MMTKernel::ScalarType::I32; + kernel.promoteSmallTypes = true; + kernel.useIntel = true; + kernel.m0 = 1; + kernel.k0 = 2; + kernel.n0 = 4; + kernel.split0 = 16; + kernel.lhsRegSize = 32; + kernel.rhsRegSize = 32; + kernel.accRegSize = 16; + kernel.lhsRegs = 1; + kernel.rhsRegs = 4; + kernel.accRegs = 4; + kernel.asmImpl = R"ASM( + vpmaddwd zmm17, $(rhs:0), $(lhs:0) + vpmaddwd zmm18, $(rhs:1), $(lhs:0) + vpmaddwd zmm19, $(rhs:2), $(lhs:0) + vpmaddwd zmm20, $(rhs:3), $(lhs:0) + vpaddw $(acc:0), $(acc:0), zmm17 + vpaddw $(acc:1), $(acc:1), zmm18 + vpaddw $(acc:2), $(acc:2), zmm19 + vpaddw $(acc:3), $(acc:3), zmm20 + )ASM"; + kernel.asmClobbers = "zmm17,zmm18,zmm19,zmm20"; + return kernel; +} + // Constructs the mlir::Type corresponding to a scalar type. Type mlirType(MLIRContext *context, MMTKernel::ScalarType t) { switch (t) { case MMTKernel::ScalarType::None: break; + case MMTKernel::ScalarType::I4: + return IntegerType::get(context, 4, IntegerType::Signless); case MMTKernel::ScalarType::I8: return IntegerType::get(context, 8, IntegerType::Signless); + case MMTKernel::ScalarType::I16: + return IntegerType::get(context, 16, IntegerType::Signless); case MMTKernel::ScalarType::I32: return IntegerType::get(context, 32, IntegerType::Signless); case MMTKernel::ScalarType::F32: @@ -704,7 +918,7 @@ class MMTKernelGenerator { ArrayRef acc) { validateOperands(lhs, rhs, acc); if (kernel.asmImpl) { - return generateAsm(rewriter, loc, lhs, rhs, acc); + return generateAsm(rewriter, loc, lhs, rhs, acc, kernel.useIntel); } // In the future we may have alternate generator paths, e.g. 1D intrinsics // or other asm paths with a different interface, e.g. handling also @@ -754,10 +968,17 @@ class MMTKernelGenerator { validate(acc, kernel.accRegs, getAccRegVectorType()); } // Helper for generateAsmCodeAndConstraints - std::string getConstraintCode() const { + std::string + getConstraintCode(std::optional kernelConstraintCode) const { + if (kernelConstraintCode) { + return std::string(*kernelConstraintCode); + } if (isAArch64(target)) { return "w"; } + if (isX86(target)) { + return "v"; + } assert(false && "what constraint code to use on this arch?"); return {}; } @@ -819,31 +1040,39 @@ class MMTKernelGenerator { // processedIdx is the index of a register in the processed asm. // Example: $5 => processedIdx == 5 int processedIdx = 0; - auto processOperands = [&](Constraints::Kind constraintKind, - const char *name, int count) { - const std::string &constraintCode = getConstraintCode(); - // unprocessedIdx is the index of a register in the unprocessed asm. - // Example: $(lhs:1) => unprocessedIdx == 1 - for (int unprocessedIdx = 0; unprocessedIdx < count; - ++unprocessedIdx, ++processedIdx) { - constraints.add(constraintKind, constraintCode); - // Perform the code replacement for the operand. - // Example: $(lhs:1) => $5 - replaceAllSubstrsInPlace( - code, llvm::formatv("$({0}:{1})", name, unprocessedIdx), - llvm::formatv("${0}", processedIdx)); - } - }; - processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs); - processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs); - processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs); + auto processOperands = + [&](Constraints::Kind constraintKind, const char *name, int count, + std::optional> kernelCodes) { + const std::string &constraintCode = getConstraintCode(std::nullopt); + // unprocessedIdx is the index of a register in the unprocessed asm. + // Example: $(lhs:1) => unprocessedIdx == 1 + for (int unprocessedIdx = 0; unprocessedIdx < count; + ++unprocessedIdx, ++processedIdx) { + if (kernelCodes) { + constraints.add(constraintKind, (*kernelCodes)[unprocessedIdx]); + } else { + constraints.add(constraintKind, constraintCode); + } + // Perform the code replacement for the operand. + // Example: $(lhs:1) => $5 + replaceAllSubstrsInPlace( + code, llvm::formatv("$({0}:{1})", name, unprocessedIdx), + llvm::formatv("${0}", processedIdx)); + } + }; + processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs, + kernel.accCode); + processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs, + kernel.lhsCode); + processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs, + kernel.rhsCode); constraints.setClobbers(kernel.asmClobbers); constraintsString = constraints.toString(); } // Helper for generate(). Implements the asm path. SmallVector generateAsm(PatternRewriter &rewriter, Location loc, ArrayRef lhs, ArrayRef rhs, - ArrayRef acc) { + ArrayRef acc, bool useIntel) { SmallVector inputs; // First the input operands. Then the input-output operands, which, as far // as input constraints are concerned, are *tied* inputs, i.e. refer to @@ -863,9 +1092,13 @@ class MMTKernelGenerator { SmallVector outputOperandTypes( llvm::map_range(acc, [](Value v) { return v.getType(); })); auto returnType = - LLVM::LLVMStructType::getLiteral(context, outputOperandTypes); + outputOperandTypes.size() == 1 + ? outputOperandTypes[0] + : LLVM::LLVMStructType::getLiteral(context, outputOperandTypes); auto dialectAttr = - LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT); + useIntel + ? LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_Intel) + : LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT); std::string code; std::string constraints; generateAsmCodeAndConstraints(code, constraints); @@ -875,10 +1108,14 @@ class MMTKernelGenerator { /*operand_attrs=*/ArrayAttr()); // Extract result vectors from the asm op. SmallVector resVec; - for (int i = 0; i < kernel.accRegs; ++i) { - SmallVector position = {i}; - resVec.push_back( - rewriter.create(loc, asmOp.getRes(), position)); + if (outputOperandTypes.size() == 1) { + resVec.push_back(asmOp.getRes()); + } else { + for (int i = 0; i < kernel.accRegs; ++i) { + SmallVector position = {i}; + resVec.push_back(rewriter.create( + loc, asmOp.getRes(), position)); + } } return resVec; } @@ -913,7 +1150,9 @@ class MMTCustomKernelPattern : public OpRewritePattern { // Check if `contractionOp` matches, and obtain the (un-promoted) input // LHS and RHS vectors. bool transposeKernel = false; - if (!matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, + if (!matchVMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, kernel.split0, + &transposeKernel) && + !matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, &transposeKernel)) { return failure(); } @@ -928,9 +1167,11 @@ class MMTCustomKernelPattern : public OpRewritePattern { return failure(); } Value unpromotedLhs = - getUnpromotedInput(lhsElemType, accElemType, contractionOp.getLhs()); + getUnpromotedInput(rewriter, lhsElemType, accElemType, + contractionOp.getLhs(), kernel.promoteSmallTypes); Value unpromotedRhs = - getUnpromotedInput(rhsElemType, accElemType, contractionOp.getRhs()); + getUnpromotedInput(rewriter, rhsElemType, accElemType, + contractionOp.getRhs(), kernel.promoteSmallTypes); if (!unpromotedLhs || !unpromotedRhs) { return failure(); } @@ -952,9 +1193,23 @@ class MMTCustomKernelPattern : public OpRewritePattern { // `contractionOp` matches, start rewriting it. Location loc = contractionOp.getLoc(); // Flatten the inputs to 1D vectors. - Value flatLhs = flatten(rewriter, loc, unpromotedLhs); - Value flatRhs = flatten(rewriter, loc, unpromotedRhs); - Value flatAcc = flatten(rewriter, loc, contractionOp.getAcc()); + VectorType lhsRegVectorType = generator.getLhsRegVectorType(); + VectorType rhsRegVectorType = generator.getRhsRegVectorType(); + VectorType accRegVectorType = generator.getAccRegVectorType(); + Value lhs, rhs; + if (transposeKernel) { + lhs = + flattenImperfectSize(rewriter, loc, unpromotedLhs, rhsRegVectorType); + rhs = + flattenImperfectSize(rewriter, loc, unpromotedRhs, lhsRegVectorType); + } else { + lhs = + flattenImperfectSize(rewriter, loc, unpromotedLhs, lhsRegVectorType); + rhs = + flattenImperfectSize(rewriter, loc, unpromotedRhs, rhsRegVectorType); + } + Value acc = flattenImperfectSize(rewriter, loc, contractionOp.getAcc(), + accRegVectorType); // Slice into SIMD-register-sized 1D input vectors ready to feed to the // target SIMD instructions. auto sliceIntoRegVectors = [&](int regsCount, VectorType regVectorType, @@ -967,17 +1222,14 @@ class MMTCustomKernelPattern : public OpRewritePattern { } return regVectors; }; - VectorType lhsRegVectorType = generator.getLhsRegVectorType(); - VectorType rhsRegVectorType = generator.getRhsRegVectorType(); - VectorType accRegVectorType = generator.getAccRegVectorType(); - Value flatLhsForKernel = transposeKernel ? flatRhs : flatLhs; - Value flatRhsForKernel = transposeKernel ? flatLhs : flatRhs; + Value lhsForKernel = transposeKernel ? rhs : lhs; + Value rhsForKernel = transposeKernel ? lhs : rhs; SmallVector lhsRegVectors = - sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, flatLhsForKernel); + sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, lhsForKernel); SmallVector rhsRegVectors = - sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, flatRhsForKernel); + sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, rhsForKernel); SmallVector accRegVectors = - sliceIntoRegVectors(kernel.accRegs, accRegVectorType, flatAcc); + sliceIntoRegVectors(kernel.accRegs, accRegVectorType, acc); // Generate the kernel! SmallVector resRegVectors = generator.generate( rewriter, loc, lhsRegVectors, rhsRegVectors, accRegVectors); @@ -1036,8 +1288,8 @@ struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics return failure(); } - Value inLhs = getUnpromotedInput(I8Type, I32Type, lhs); - Value inRhs = getUnpromotedInput(I8Type, I32Type, rhs); + Value inLhs = getUnpromotedInput(rewriter, I8Type, I32Type, lhs); + Value inRhs = getUnpromotedInput(rewriter, I8Type, I32Type, rhs); if (!inLhs || !inRhs) return failure(); @@ -1170,6 +1422,15 @@ void populateVectorContractCustomKernelsPatterns( patterns.add( context, MMTKernel_8x8x8_i8i8i32_Aarch64I8mm_InlineAsm()); } + } else if (isX86(target)) { + if (hasFeature(target, "+avx512vnni")) { + patterns.add( + context, + MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm()); + } else if (hasFeature(target, "+avx512bw")) { + patterns.add( + context, MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512_InlineAsm()); + } } } diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index 9c1ddf178628..88ce7a1672cf 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -68,7 +68,9 @@ void buildGlobalOptimizationPassPipeline( .addPass(createRemoveZeroExtentTensorsPass) .addPass(createDetachElementwiseFromNamedOpsPass) .addPass(mlir::createLinalgNamedOpConversionPass) - .addPass(createConvert1X1FilterConv2DToMatmulPass); + .addPass(createConvert1X1FilterConv2DToMatmulPass) + .addPredicatedPass(!clEnableQuantizedMatmulReassociation, + createLiftGenericToTransposeBatchMatmulPass); mainPassManager.addPass(createEraseUnusedLinalgOperands()); // Expand tensor shapes into SSA values and optimize the whole program. From 9c914987953fd3b4f43b7ed8f359e98362b8f95c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 19 Oct 2023 09:18:38 -0500 Subject: [PATCH 30/38] LLVMCPU reduction tiling rebase conflict fix Co-authored-by: Max Dawkins --- .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index aea660c51bf2..381c76d9212b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -170,30 +170,19 @@ LogicalResult verifyDoubleTilingExpertPassPipelineConfig( } } - SmallVector thirdLevelTileSizes; - std::tie(thirdLevelTileSizes, std::ignore) = - tilingConfig.getVectorReductionSizes(); - for (auto [index, tileSize] : llvm::enumerate(thirdLevelTileSizes)) { - if (tileSize != 0 && pLoopsSet.contains(index)) { - return op->emitOpError( - "expected only reduction dims to be set in the third tiling " - "level, got ") - << index << "-th tile size set"; + if (!clEnableQuantizedMatmulReassociation) { + SmallVector thirdLevelTileSizes; + std::tie(thirdLevelTileSizes, std::ignore) = + tilingConfig.getVectorReductionSizes(); + for (auto [index, tileSize] : llvm::enumerate(thirdLevelTileSizes)) { + if (tileSize != 0 && pLoopsSet.contains(index)) { + return op->emitOpError("expected only reduction dims to be set in " + "the third tiling " + "level, got ") + << index << "-th tile size set"; + } } } - // if (!clEnableQuantizedMatmulReassociation) { - // SmallVector thirdLevelTileSizes; - // std::tie(thirdLevelTileSizes, std::ignore) = - // tilingConfig.getVectorReductionSizes(); - // for (auto [index, tileSize] : llvm::enumerate(thirdLevelTileSizes)) { - // if (tileSize != 0 && pLoopsSet.contains(index)) { - // return op->emitOpError("expected only reduction dims to be set in " - // "the third tiling " - // "level, got ") - // << index << "-th tile size set"; - // } - // } - // } } // Verify interchange From 5d3f9def11ea84a3a3e3ba6b2e3f6072698ed5bb Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Wed, 13 Sep 2023 13:44:02 -0400 Subject: [PATCH 31/38] [LLVMCPU] Add tiling config pass for special non-root op tiling cases This commit adds a new tiling configuration pass in LLVMCPU. This pass sets a special tiling configuration for reassociated quantized matmuls, since the non-root op of these dispatches require specific tiling to target certain x86 instructions. This pass is a place to set abnormal tile sizes on non-root ops for specific types of workloads. --- .../iree/compiler/Codegen/LLVMCPU/BUILD.bazel | 1 + .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 1 + .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 4 +- .../iree/compiler/Codegen/LLVMCPU/Passes.h | 3 + .../iree/compiler/Codegen/LLVMCPU/Passes.td | 6 + .../LLVMCPU/SetSpecialTilingConfigs.cpp | 341 ++++++++++++++++++ 6 files changed, 355 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index c045131f87f2..eba4bef21317 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -71,6 +71,7 @@ iree_compiler_cc_library( "LLVMCPUUnfuseFMAOps.cpp", "LLVMCPUVectorLowering.cpp", "Passes.cpp", + "SetSpecialTilingConfigs.cpp", "TargetMLTransformInfo.cpp", "Utils.cpp", "VectorContractCustomKernels.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index c7278c37999e..e280dc0e3208 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -72,6 +72,7 @@ iree_cc_library( "LLVMCPUUnfuseFMAOps.cpp" "LLVMCPUVectorLowering.cpp" "Passes.cpp" + "SetSpecialTilingConfigs.cpp" "TargetMLTransformInfo.cpp" "Utils.cpp" "VectorContractCustomKernels.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 381c76d9212b..710a434cb6d2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -690,8 +690,10 @@ static void addLowerToLLVMPasses(OpPassManager &passManager, void buildLLVMCPUCodegenConfigurationPassPipeline(OpPassManager &passManager) { { - addCommonTargetExecutablePreprocessingPasses(passManager); OpPassManager &modulePassManager = passManager.nest(); + modulePassManager.addNestedPass( + createSetSpecialTilingConfigsPass()); + addCommonTargetExecutablePreprocessingPasses(passManager); modulePassManager.addNestedPass( createRematerializeParallelOpsPass()); // TODO(#13888): This(createExpandF16OpToF32Pass()) pass is being added way diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index cc00401b9bbb..8a370c510b5b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -107,6 +107,9 @@ createLLVMCPUFoldVectorContractUnitDimsPass(); std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass(); +std::unique_ptr> +createSetSpecialTilingConfigsPass(); + //------------------------------------------------------------------------------ // LLVMCPU Codegen specific patterns. //------------------------------------------------------------------------------ diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index a03f9cd575be..5f9a81baec00 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -198,6 +198,12 @@ def VectorContractCustomKernels : let constructor = "mlir::iree_compiler::createVectorContractCustomKernelsPass()"; } +def SetSpecialTilingConfigs : + Pass<"iree-llvmcpu-set-special-tiling-configs", "func::FuncOp"> { + let summary = "Set the tile sizes for special cases before KernelDispatch."; + let constructor = "mlir::iree_compiler::createSetSpecialTilingConfigsPass()"; +} + def VerifyLinalgTransformLegality : Pass<"iree-llvmcpu-verify-linalg-transform-legality", "ModuleOp"> { let summary = "Verify that only supported IR constructs are passed to the compiler."; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp new file mode 100644 index 000000000000..789d67a71351 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp @@ -0,0 +1,341 @@ +#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "iree/compiler/Codegen/Utils/Utils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-set-special-tiling-configs" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { +namespace { + +static void setTileSizes(linalg::GenericOp intMatmul, + linalg::GenericOp reassociation, + func::FuncOp entryPointFn, + IREE::HAL::ExecutableTargetAttr target) { + int mDistSize = 1; + int nDistSize = 128; + int mSize = 1; + int nSize = 4; + int kSize = 8; + int groupSize = 1; + SmallVector mDims; + SmallVector nDims; + SmallVector kDims; + SmallVector groupDims; + SmallVector maps = intMatmul.getIndexingMapsArray(); + int lhs = 0; + int rhs = 1; + int out = 2; + auto hasDim = [&](int mapIdx, int dimIdx) -> bool { + return llvm::any_of(maps[mapIdx].getResults(), [&](AffineExpr res) { + auto expr = res.dyn_cast(); + return expr && expr.getPosition() == dimIdx; + }); + }; + for (int dim = 0; dim < intMatmul.getNumLoops(); dim++) { + if (hasDim(lhs, dim) && hasDim(rhs, dim) && hasDim(out, dim)) { + groupDims.push_back(dim); + } else if (hasDim(lhs, dim) && hasDim(rhs, dim) && !hasDim(out, dim)) { + kDims.push_back(dim); + } else if (hasDim(lhs, dim) && !hasDim(rhs, dim) && hasDim(out, dim)) { + mDims.push_back(dim); + } else if (!hasDim(lhs, dim) && hasDim(rhs, dim) && hasDim(out, dim)) { + nDims.push_back(dim); + } + } + if (hasFeature(target, "+avx512bw") || hasFeature(target, "+avx512vnni")) { + kSize = 16; + } + + if (mDims.size() > 1 || nDims.size() > 1 || kDims.size() != 1 || + kDims[0] != intMatmul.getNumLoops() - 1) { + return; + } + + SmallVector distTileSizes_mm(intMatmul.getNumLoops(), 0); + SmallVector parallelTileSizes_mm(intMatmul.getNumLoops(), 0); + SmallVector reductionTileSizes_mm(intMatmul.getNumLoops(), 0); + SmallVector lastTileSizes_mm(intMatmul.getNumLoops(), 0); + + SmallVector distTileSizes_re(reassociation.getNumLoops(), 0); + SmallVector parallelTileSizes_re(reassociation.getNumLoops(), 0); + SmallVector reductionTileSizes_re(reassociation.getNumLoops(), 0); + SmallVector lastTileSizes_re(reassociation.getNumLoops(), 0); + + for (int mDim : mDims) { + distTileSizes_mm[mDim] = mDistSize; + parallelTileSizes_mm[mDim] = mSize; + reductionTileSizes_mm[mDim] = mSize; + + distTileSizes_re[mDim] = mDistSize; + parallelTileSizes_re[mDim] = mSize; + } + for (int nDim : nDims) { + distTileSizes_mm[nDim] = nDistSize; + parallelTileSizes_mm[nDim] = nSize; + reductionTileSizes_mm[nDim] = nSize; + + distTileSizes_re[nDim] = nDistSize; + parallelTileSizes_re[nDim] = nSize; + } + for (int kDim : kDims) { + reductionTileSizes_mm[kDim] = kSize; + } + for (int groupDim : groupDims) { + reductionTileSizes_mm[groupDim] = groupSize; + } + + TileSizesListType tileSizes_mm; + tileSizes_mm.push_back(distTileSizes_mm); + tileSizes_mm.push_back(parallelTileSizes_mm); + tileSizes_mm.push_back(reductionTileSizes_mm); + tileSizes_mm.push_back(lastTileSizes_mm); + + TileSizesListType tileSizes_re; + tileSizes_re.push_back(distTileSizes_re); + tileSizes_re.push_back(parallelTileSizes_re); + tileSizes_re.push_back(reductionTileSizes_re); + tileSizes_re.push_back(lastTileSizes_re); + + IREE::Codegen::DispatchLoweringPassPipeline passPipeline = + IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert; + + MLIRContext *context = entryPointFn.getContext(); + auto config_mm = + IREE::Codegen::LoweringConfigAttr::get(context, tileSizes_mm); + intMatmul->setAttr("lowering_config", config_mm); + + auto config_re = + IREE::Codegen::LoweringConfigAttr::get(context, tileSizes_re); + auto translationInfo_re = IREE::Codegen::TranslationInfoAttr::get( + entryPointFn.getContext(), passPipeline, 0, 1); + auto compilationInfo_re = IREE::Codegen::CompilationInfoAttr::get( + context, config_re, translationInfo_re, ArrayRef({}), + std::nullopt); + + reassociation->setAttr("compilation_info", compilationInfo_re); + + return; +} + +static bool isIntegerMatmul(linalg::GenericOp genericOp) { + if (genericOp.getNumDpsInits() != 1) { + LDBG("Wrong number of outputs for matmul: " << genericOp.getNumDpsInits() + << "\n"); + return false; + } + if (genericOp.getNumDpsInputs() != 2) { + LDBG("Wrong number of inputs for matmul: " << genericOp.getNumDpsInputs() + << "\n"); + return false; + } + + unsigned numLoops = genericOp.getNumLoops(); + unsigned numReductionLoops = genericOp.getNumReductionLoops(); + if (numLoops != 3) { + LDBG("Wrong number of loops for matmul: " << numLoops << "\n"); + return false; + } + if (numReductionLoops != 1) { + LDBG("Wrong number of reduction loops for matmul: " << numReductionLoops + << "\n"); + return false; + } + // Work back from linalg.yield and check body of genericOp. + auto yieldOp = cast(genericOp.getBody()->getTerminator()); + Value producerOutput; + Operation *producer; + Operation *mulRhsProducer; + + // Producer of linalg.yield op is arith.addi + { + producerOutput = yieldOp->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.addi op is arith.muli + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.muli op RHS is arith.extui + { + producerOutput = producer->getOperand(1); + mulRhsProducer = producerOutput.getDefiningOp(); + if (!mulRhsProducer || mulRhsProducer->getNumOperands() == 0) + return false; + if (!matchPattern(mulRhsProducer, m_Op())) + return false; + } + + // Producer of arith.subf op LHS is arith.extsi + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + return true; +} + +static bool isReassociatedDequantizationOp(linalg::GenericOp genericOp) { + if (genericOp.getNumDpsInits() != 1) { + LDBG("Wrong number of outputs: " << genericOp.getNumDpsInits() << "\n"); + return false; + } + if (genericOp.getNumDpsInputs() != 5) { + LDBG("Wrong number of inputs: " << genericOp.getNumDpsInputs() << "\n"); + return false; + } + + unsigned numLoops = genericOp.getNumLoops(); + unsigned numReductionLoops = genericOp.getNumReductionLoops(); + if (numLoops != 2) { + LDBG("Wrong number of loops: " << numLoops << "\n"); + return false; + } + if (numReductionLoops != 1) { + LDBG("Wrong number of reduction loops: " << numReductionLoops << "\n"); + return false; + } + // Work back from linalg.yield and check body of genericOp. + auto yieldOp = cast(genericOp.getBody()->getTerminator()); + Value producerOutput; + Operation *producer; + Operation *subRhsProducer; + + // Producer of linalg.yield op is arith.addf + { + producerOutput = yieldOp->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.addf op is arith.subf + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.subf op RHS is arith.mulf + { + producerOutput = producer->getOperand(1); + subRhsProducer = producerOutput.getDefiningOp(); + if (!subRhsProducer || subRhsProducer->getNumOperands() == 0) + return false; + if (!matchPattern(subRhsProducer, m_Op())) + return false; + } + + // Producer of arith.mulf from arith.subf RHS is arith.mulf + { + producerOutput = subRhsProducer->getOperand(0); + subRhsProducer = producerOutput.getDefiningOp(); + if (!subRhsProducer || subRhsProducer->getNumOperands() == 0) + return false; + if (!matchPattern(subRhsProducer, m_Op())) + return false; + } + + // Producer of arith.subf op LHS is arith.mulf + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.mulf op is arith.mulf + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.mulf op is arith.sitofp + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + return true; +} + +struct SetSpecialTilingConfigsPass + : public SetSpecialTilingConfigsBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); + std::optional> + reassociatedQuantizedMatmulOps = std::nullopt; + for (auto genericOp : + funcOp.getFunctionBody().getOps()) { + if (isReassociatedDequantizationOp(genericOp)) { + auto intMatmulOp = + genericOp.getInputs()[0].getDefiningOp(); + if (intMatmulOp) { + if (isIntegerMatmul(intMatmulOp)) { + reassociatedQuantizedMatmulOps = + std::make_pair(intMatmulOp, genericOp); + break; + } + } + } + } + if (reassociatedQuantizedMatmulOps) { + setTileSizes(reassociatedQuantizedMatmulOps->first, + reassociatedQuantizedMatmulOps->second, funcOp, target); + } + } +}; +} // namespace + +std::unique_ptr> +createSetSpecialTilingConfigsPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir From fbffcb55953c426e41f2ad5c515968dcc7f31aeb Mon Sep 17 00:00:00 2001 From: PhaneeshB Date: Wed, 15 Nov 2023 20:19:26 +0530 Subject: [PATCH 32/38] update deprecated AffineExpr casts --- .../Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp | 2 +- .../iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp | 2 +- .../compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp | 2 +- .../Preprocessing/Common/ConvertConvToChannelsLast.cpp | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp index 12388c2dd8f1..1096693c693d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp @@ -54,7 +54,7 @@ dropFoldableUnitIndices(PatternRewriter &rewriter, SmallVector dstShape; SmallVector dstExpr; for (const auto &expr : enumerate(map.getResults())) { - if (auto dimExpr = expr.value().dyn_cast()) { + if (auto dimExpr = llvm::dyn_cast(expr.value())) { if (!foldableDims.contains(dimExpr.getPosition())) { dstShape.push_back(contractShape[dimExpr.getPosition()]); unsigned numSkipped = 0; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp index 789d67a71351..0b016bd947eb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp @@ -37,7 +37,7 @@ static void setTileSizes(linalg::GenericOp intMatmul, int out = 2; auto hasDim = [&](int mapIdx, int dimIdx) -> bool { return llvm::any_of(maps[mapIdx].getResults(), [&](AffineExpr res) { - auto expr = res.dyn_cast(); + auto expr = llvm::dyn_cast(res); return expr && expr.getPosition() == dimIdx; }); }; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp index 202e17521a84..3dca50496496 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp @@ -146,7 +146,7 @@ static bool isVectorTimesMatrixTransposed(vector::ContractionOp contractionOp, } for (int r = 0; r < expectedResults.size(); ++r) { int actualMapResult = - map.getResults()[r].cast().getPosition(); + llvm::cast(map.getResults()[r]).getPosition(); if (actualMapResult != expectedMapResults[m][r]) { return false; } diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp index 4a8b833bfb1a..9c40792456a4 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp @@ -260,11 +260,11 @@ static TransposeIndices collectChannelTransposeIndices( AffineMap map, SmallVector> transposeDimTargets) { SmallVector channelIndices(transposeDimTargets.size()); for (auto [index, result] : llvm::enumerate(map.getResults())) { - if (result.isa()) { + if (llvm::isa(result)) { for (auto [channelVec, dimCategory] : llvm::zip_equal(channelIndices, transposeDimTargets)) { if (llvm::is_contained(dimCategory, - result.cast().getPosition())) { + llvm::cast(result).getPosition())) { channelVec.push_back(index); break; } From adfa15d7e576ce573458132262143b81af5c4d14 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Wed, 29 Nov 2023 11:02:05 -0500 Subject: [PATCH 33/38] add back fusion in reassociated quantized matmul --- .../GlobalOptimization/FuseDequantizationMatmul.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp index 7ec2388020b4..b70f4d056633 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp @@ -776,6 +776,14 @@ static LogicalResult reassociateDequantMatmul(RewriterBase &rewriter, rewriter.replaceOp(matmul, reassociatedDequantization.getResult(0)); + // Fuse dequantization + matmul ops into a single dispatch region + SmallVector dequantMatmulOps{quantizedIntegerMatmul, + reassociatedDequantization}; + FailureOr maybeDequantMatmulDispatch = + wrapConsecutiveOpsInDispatchRegion(rewriter, dequantMatmulOps); + if (failed(maybeDequantMatmulDispatch)) { + return failure(); + } return success(); } From e49584e7225003bb37b97dd0bf1dcffeca3029b8 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Fri, 1 Dec 2023 16:06:09 -0500 Subject: [PATCH 34/38] restrict lifting of generic ops to batch_matmul From d2403315f8e628479aaf8c37f3d29a2d3658d086 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Fri, 1 Dec 2023 16:30:24 -0500 Subject: [PATCH 35/38] check fusion in quantized matmul reassociation lit test --- .../GlobalOptimization/test/fuse_dequantization_matmul.mlir | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir index fd8d196bcebe..e535c52eafcb 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir @@ -97,6 +97,7 @@ module { // REASSOCIATE-CHECK: %[[INITMATMUL:.+]] = tensor.empty() : tensor<11008x32xi32> // REASSOCIATE-CHECK: %[[FILLMATMUL:.+]] = linalg.fill ins(%[[C0I32]] // REASSOCIATE-CHECK-SAME: outs(%[[INITMATMUL]] : +// REASSOCIATE-CHECK: %[[DISP:.+]] = flow.dispatch.region // REASSOCIATE-CHECK: %[[GENMATMUL:.+]] = linalg.generic // REASSOCIATE-CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]] // REASSOCIATE-CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] @@ -122,4 +123,5 @@ module { // REASSOCIATE-CHECK: %[[RESUBF:.+]] = arith.subf %[[REMULF1]], %[[REMULF3]] : f32 // REASSOCIATE-CHECK: %[[READDF:.+]] = arith.addf %[[RESUBF]], %[[REOUT0]] : f32 // REASSOCIATE-CHECK: linalg.yield %[[READDF]] : f32 -// REASSOCIATE-CHECK: return %[[GENREASSOCIATE]] +// REASSOCIATE-CHECK: flow.return %[[GENREASSOCIATE]] +// REASSOCIATE-CHECK: return %[[DISP]] From 81bfe263863a7939a224f1569695ce0bfe37626b Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 5 Dec 2023 19:24:41 -0500 Subject: [PATCH 36/38] [LLVMGPU] Add multi-row vector reduction configuration (#73) This is to speed up matvec. The new configuration is experimental and only applied on ROCm targets. --- .../Common/GPU/VectorReductionToGPU.cpp | 2 +- .../GPU/test/vector_reduction_to_gpu.mlir | 65 +++++++++++++++++++ .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 21 ++++++ .../Codegen/LLVMGPU/test/config_matvec.mlir | 47 ++++++++++++++ 4 files changed, 134 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp index 2c71aa9444fa..fcfec190ab41 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp @@ -195,7 +195,7 @@ class VectorReductionToGPUPass bool expandSubgroupReduction, std::function getWarpSize) : expandSubgroupReduction(expandSubgroupReduction), - getWarpSize(getWarpSize) {} + getWarpSize(std::move(getWarpSize)) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert>, vector<1xf32> // CHECK: vector.transfer_write {{.*}} : vector<1xf32>, memref<128x32xf32> // CHECK: return + + +// ----- + +// Check that we multi-row matvec gets distributed across subgoroup threads. + +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}> +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @multirow { + hal.executable.variant @rocm target(#executable_target_rocm_hsaco_fb) { + hal.executable.export @multirow layout(#pipeline_layout) attributes { + workgroup_size = [64 : index, 1 : index, 1 : index] + } + builtin.module { + func.func @multirow() { + %cst = arith.constant dense<0.000000e+00> : vector<4x512xf16> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<0.000000e+00> : vector<1x4xf16> + %c4096 = arith.constant 4096 : index + %c512 = arith.constant 512 : index + %cst_1 = arith.constant 0.000000e+00 : f16 + %id = gpu.thread_id x + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x4096xf16, #hal.descriptor_type> + memref.assume_alignment %0, 64 : memref<1x4096xf16, #hal.descriptor_type> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<32000x4096xf16, #hal.descriptor_type> + memref.assume_alignment %1, 64 : memref<32000x4096xf16, #hal.descriptor_type> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<1x32000xf16, #hal.descriptor_type> + memref.assume_alignment %2, 64 : memref<1x32000xf16, #hal.descriptor_type> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x] + %4 = scf.for %arg0 = %c0 to %c4096 step %c512 iter_args(%arg1 = %cst) -> (vector<4x512xf16>) { + %8 = vector.transfer_read %0[%c0, %arg0], %cst_1 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (0, d1)>} : memref<1x4096xf16, #hal.descriptor_type>, vector<4x512xf16> + %9 = vector.transfer_read %1[%3, %arg0], %cst_1 {in_bounds = [true, true]} : memref<32000x4096xf16, #hal.descriptor_type>, vector<4x512xf16> + %10 = arith.mulf %8, %9 : vector<4x512xf16> + %11 = arith.addf %arg1, %10 : vector<4x512xf16> + scf.yield %11 : vector<4x512xf16> + } + %5 = vector.broadcast %4 : vector<4x512xf16> to vector<1x4x512xf16> + %6 = vector.multi_reduction , %5, %cst_0 [2] : vector<1x4x512xf16> to vector<1x4xf16> + %7 = vector.extract %6[0] : vector<4xf16> from vector<1x4xf16> + vector.transfer_write %7, %2[%c0, %3] {in_bounds = [true]} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type> + return + } + } + } +} + +// CHECK-LABEL: func.func @multirow() { +// CHECK: scf.for {{.*}} -> (vector<4x8xf16>) { +// CHECK: vector.transfer_read {{.*}} : memref<32000x4096xf16, #hal.descriptor_type>, vector<4x8xf16> +// CHECK: vector.transfer_read {{.*}} : memref<1x4096xf16, #hal.descriptor_type>, vector<4x8xf16> +// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<4x8xf16> +// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4x8xf16> +// CHECK: } +// CHECK: gpu.shuffle xor +// CHECK: scf.if {{.*}} { +// CHECK: vector.transfer_write {{.*}} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type> +// CHECK: } +// CHECK-NEXT: return diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index ef26385f57ea..577c248cc09e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -22,7 +22,9 @@ #include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -924,6 +926,25 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, if ((groupSize / subgroupSize) > subgroupSize) return failure(); + // With just one subgroup per workgroup, make each subgroup do more work and + // process a few reductions along the last parallel dimension. + // TODO: We should also check that this will result in data reuse for at least + // one argument. + // TODO: This is experimental for matvec (matmul_transpose_b) on rocm-only for + // now. + if (numDynamicReductionDims == 0 && numParallelDims == 2 && + isRocmTarget(entryPoint)) { + if (*parallelSize && !parallelDims.empty() && groupSize == subgroupSize) { + int maxParallelFactor = 4; // Keeping this conservative for now. + int64_t lastParallelBound = bounds[parallelDims.back()]; + if (!ShapedType::isDynamic(lastParallelBound) && + (lastParallelBound % maxParallelFactor == 0) && + lastParallelBound > maxParallelFactor) { + workgroupTileSizes.back() = maxParallelFactor; + } + } + } + std::array workgroupSize = {groupSize, 1, 1}; SmallVector reductionTileSizes(op.getNumLoops(), 0); int64_t remainingGroupSize = groupSize; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir index 47b34315160d..2cfa7a8b3aeb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir @@ -50,3 +50,50 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gf // CHECK: func.func @dynamic_batch_matvec() // CHECK: linalg.batch_matmul // CHECK-SAME: lowering_config = #[[$CONFIG]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> + +hal.executable @vmt { +hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>) { + hal.executable.export @vmt layout(#pipeline_layout) + builtin.module { + func.func @vmt() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 4096], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<1x4096xf16> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<32000x4096xf16> + %5 = tensor.empty() : tensor<1x32000xf16> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<1x32000xf16>) -> tensor<1x32000xf16> + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<1x4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<1x32000xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %out, %8 : f16 + linalg.yield %9 : f16 + } -> tensor<1x32000xf16> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 32000], strides = [1, 1] : tensor<1x32000xf16> -> !flow.dispatch.tensor> + return + } + } + } +} + +// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-LABEL: hal.executable.export public @vmt +// CHECK-SAME: subgroup_size = 64 : index +// CHECK-SAME: translation_info = #[[$TRANSLATION]] +// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index] +// CHECK: func.func @vmt() +// CHECK: linalg.generic +// CHECK-SAME: lowering_config = #[[$CONFIG]] From 3d9e3f32b4d0e9be57c683ff325f79e5ec392c60 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Wed, 13 Dec 2023 01:35:53 -0600 Subject: [PATCH 37/38] Single line syntax + introduce wrapConsectutiveOp header to fuseDequant. --- .../GlobalOptimization/FuseDequantizationMatmul.cpp | 1 + .../Preprocessing/Common/ConvertConvNchwToNhwc.cpp | 8 ++------ .../Preprocessing/Common/ConvertConvToChannelsLast.cpp | 8 ++------ .../Preprocessing/Common/GeneralizeConvolutions.cpp | 8 ++------ compiler/src/iree/compiler/Preprocessing/Common/Passes.td | 6 +++--- 5 files changed, 10 insertions(+), 21 deletions(-) diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp index b70f4d056633..49d895540b29 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp @@ -8,6 +8,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" +#include "iree/compiler/GlobalOptimization/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp index 5704f18ed466..fdf60f9e4493 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp @@ -18,9 +18,7 @@ #define DEBUG_TYPE "iree-flow-convert-conv-nchw-to-nhwc" -namespace mlir { -namespace iree_compiler { -namespace IREE { +namespace mlir::iree_compiler::Preprocessing { using TransposeIndices = SmallVector; @@ -559,6 +557,4 @@ createConvertConvNchwToNhwcPass() { return std::make_unique(); } -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir +} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp index 9c40792456a4..6ee82e170b3f 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp @@ -23,9 +23,7 @@ #define DEBUG_TYPE "iree-preprocessing-convert-conv-to-channels-last" -namespace mlir { -namespace iree_compiler { -namespace IREE { +namespace mlir::iree_compiler::Preprocessing { static const StringLiteral fullTileTransposeMarker = "__fully_transpose_tile__"; @@ -885,6 +883,4 @@ std::unique_ptr createConvertConvToChannelsLastPass() { return std::make_unique(); } -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir +} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp index ada980a788ed..f3a56404375e 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp @@ -20,9 +20,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -namespace mlir { -namespace iree_compiler { -namespace IREE { +namespace mlir::iree_compiler::Preprocessing { namespace { @@ -63,6 +61,4 @@ std::unique_ptr createGeneralizeConvolutionsPass() { return std::make_unique(); } -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir +} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index a2fb9328965f..0f523fdcd3c9 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -19,7 +19,7 @@ def ConvertConvNchwToNhwc : InterfacePass<"iree-flow-convert-conv-nchw-to-nhwc", "mlir::FunctionOpInterface"> { let summary = "Convert linalg NCHW Convolutions to NHWC"; let constructor = - "mlir::iree_compiler::IREE::createConvertConvNchwToNhwcPass()"; + "mlir::iree_compiler::Preprocessing::createConvertConvNchwToNhwcPass()"; } def MakeSingleDispatchForFunction : @@ -31,14 +31,14 @@ def MakeSingleDispatchForFunction : def GeneralizeConvolutions : Pass<"iree-preprocessing-generalize-convolutions", ""> { let summary = "Generalize all convolution ops"; - let constructor = "mlir::iree_compiler::IREE::createGeneralizeConvolutionsPass()"; + let constructor = "mlir::iree_compiler::Preprocessing::createGeneralizeConvolutionsPass()"; } def ConvertConvToChannelsLast : Pass<"iree-preprocessing-convert-conv-to-channels-last", ""> { let summary = "Convert linalg convolutions to channels last."; let constructor = - "mlir::iree_compiler::IREE::createConvertConvToChannelsLastPass()"; + "mlir::iree_compiler::Preprocessing::createConvertConvToChannelsLastPass()"; let options = [ Option<"tileSize", "tile-size", "int", /*default=*/"0", From ada6732e998242c3e06953f60f0881f90811a6cd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 11 Jan 2024 17:53:16 +0000 Subject: [PATCH 38/38] Bump jinja2 from 2.11.3 to 3.1.3 in /build_tools/benchmarks/reporting Bumps [jinja2](https://github.com/pallets/jinja) from 2.11.3 to 3.1.3. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/2.11.3...3.1.3) --- updated-dependencies: - dependency-name: jinja2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- build_tools/benchmarks/reporting/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/benchmarks/reporting/requirements.txt b/build_tools/benchmarks/reporting/requirements.txt index 9dcb2d2452cd..cb7ee9904e47 100644 --- a/build_tools/benchmarks/reporting/requirements.txt +++ b/build_tools/benchmarks/reporting/requirements.txt @@ -1,2 +1,2 @@ pandas==1.5.0 -jinja2==2.11.3 \ No newline at end of file +jinja2==3.1.3 \ No newline at end of file